"""Streamlit demo: predict an ICD-10 code class from free text.

Runs the ``AkshatSurolia/ICD-10-Code-Prediction`` sequence-classification
model on the text typed into the page and shows the predicted class.
"""

import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "AkshatSurolia/ICD-10-Code-Prediction"

# Upper bound on the number of *tokens* fed to the model. The tokenizer's own
# ``model_max_length`` takes precedence if it is smaller.
MAX_TOKENS = 512


@st.cache_resource
def load_model(model_name: str = MODEL_NAME):
    """Load the tokenizer and classification model once and cache them.

    Streamlit re-executes the whole script on every widget interaction.
    Without caching, the weights would be reloaded from disk/hub on each
    rerun. ``st.cache_resource`` keeps a single shared copy instead.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tok, mdl


def predict(text: str, tokenizer, model, max_tokens: int = MAX_TOKENS) -> int:
    """Return the index of the highest-scoring class for ``text``.

    Truncation is done by the tokenizer, in tokens. The previous approach
    sliced the string to 512 *characters*, which threw away input well before
    the model's token window was full.
    """
    # model_max_length is a huge sentinel when the tokenizer defines no
    # limit, so min() falls back to max_tokens in that case.
    limit = min(max_tokens, tokenizer.model_max_length)
    encoded = tokenizer(
        text,
        truncation=True,
        max_length=limit,
        return_tensors="pt",
    )
    # Inference only: skip autograd graph construction to save time/memory.
    with torch.no_grad():
        logits = model(**encoded).logits
    return int(torch.argmax(logits, dim=-1).item())


tokenizer, model = load_model()

input_text = st.text_input("Enter your text:")

# Ignore empty or whitespace-only input rather than classifying nothing.
if input_text.strip():
    predicted_class_idx = predict(input_text, tokenizer, model)
    st.write(f"Predicted class index: {predicted_class_idx}")

    # Label names come from the checkpoint's config. They may be generic
    # "LABEL_<n>" strings if the checkpoint does not define real names.
    label = model.config.id2label.get(predicted_class_idx)
    if label is not None:
        st.write(f"Predicted label: {label}")