File size: 932 Bytes
e6b17ef
 
 
1bce6d6
e6b17ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Hugging Face Hub id shared by the tokenizer and the classification model.
MODEL_NAME = "AkshatSurolia/ICD-10-Code-Prediction"


@st.cache_resource
def load_model_and_tokenizer(model_name=MODEL_NAME):
    """Load the tokenizer and sequence-classification model for *model_name*.

    Streamlit re-executes this script from the top on every widget
    interaction; ``st.cache_resource`` keeps the loaded objects alive across
    those reruns so the weights are not reloaded for each prediction.

    Args:
        model_name: Hub id or local path accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


# Module-level names kept so the rest of the script can keep using them.
tokenizer, model = load_model_and_tokenizer()

# Upper bound on the number of tokens fed to the model. The original script
# cut the raw string to 512 characters; the same figure is applied here at the
# token level instead.
# NOTE(review): 512 is assumed to match the model's maximum sequence length --
# confirm against the model config.
MAX_TOKENS = 512

# Free-text input box; returns an empty string until the user types something.
input_text = st.text_input("Enter your text:")

# Ignore empty and whitespace-only submissions.
cleaned_text = input_text.strip() if input_text else ""

if cleaned_text:
    # Let the tokenizer truncate by token count. Slicing the string by
    # characters would discard text the model could still have consumed and
    # does not actually bound the number of tokens.
    tokens = tokenizer(
        cleaned_text,
        truncation=True,
        max_length=MAX_TOKENS,
        return_tensors="pt",
    )

    # Inference only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        output = model(**tokens)

    # The logits have one score per class; the argmax is the predicted class.
    predicted_class_idx = torch.argmax(output.logits, dim=-1).item()

    st.write(f"Predicted class index: {predicted_class_idx}")

    # Also show the human-readable label when the model config provides one;
    # the bare index is rarely meaningful to a user.
    id2label = getattr(model.config, "id2label", None) or {}
    predicted_label = id2label.get(predicted_class_idx)
    if predicted_label is not None:
        st.write(f"Predicted label: {predicted_label}")