File size: 932 Bytes
e6b17ef 1bce6d6 e6b17ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Hugging Face Hub id of the fine-tuned ICD-10 classification model.
MODEL_NAME = "AkshatSurolia/ICD-10-Code-Prediction"
# Maximum number of *tokens* fed to the model; longer inputs are truncated by
# the tokenizer. 512 is the usual limit for BERT-style encoders -- TODO confirm
# against the model's config (max_position_embeddings).
MAX_TOKENS = 512


# Streamlit re-executes this whole script on every widget interaction; caching
# the loaded objects avoids re-reading the model weights on each rerun.
# NOTE: st.cache_resource requires Streamlit >= 1.18.
@st.cache_resource
def load_model():
    """Load the tokenizer and classification model once per server process.

    Returns:
        A ``(tokenizer, model)`` tuple for ``MODEL_NAME``.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    # Inference only: make sure dropout and similar layers are disabled.
    model.eval()
    return tokenizer, model


def predict(text, tokenizer, model, max_tokens=MAX_TOKENS):
    """Classify ``text`` and return the predicted class.

    Args:
        text: Raw input string.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Sequence-classification model returned by ``load_model``.
        max_tokens: Token budget; longer inputs are truncated by the tokenizer
            (not by slicing characters, which would discard usable context).

    Returns:
        ``(class_index, label)``. ``label`` comes from ``model.config.id2label``
        and falls back to the index as a string when no mapping exists.
    """
    tokens = tokenizer(
        text, truncation=True, max_length=max_tokens, return_tensors="pt"
    )
    # No gradients are needed for inference; this saves memory and time.
    with torch.no_grad():
        logits = model(**tokens).logits
    # The model outputs one logit per class; argmax picks the predicted class.
    class_idx = torch.argmax(logits, dim=-1).item()
    label = model.config.id2label.get(class_idx, str(class_idx))
    return class_idx, label


# Load the model and tokenizer (cached across reruns).
tokenizer, model = load_model()

# Create a Streamlit input text box.
input_text = st.text_input("Enter your text:")

# Only run the model when there is non-blank input.
if input_text.strip():
    predicted_class_idx, predicted_label = predict(input_text, tokenizer, model)
    st.write(f"Predicted class index: {predicted_class_idx}")
    st.write(f"Predicted label: {predicted_label}")
|