File size: 3,905 Bytes
86a2cca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import pandas as pd
import torch
import faiss
from transformers import DistilBertTokenizer, DistilBertModel
import streamlit as st
import numpy as np
# Initialize tokenizer and model from one shared checkpoint name so the
# tokenizer vocabulary always matches the encoder weights.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so both objects are reloaded each time; consider caching them
# (e.g. with st.cache_resource) if load time becomes a problem.
_PRETRAINED_NAME = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(_PRETRAINED_NAME)
model = DistilBertModel.from_pretrained(_PRETRAINED_NAME)
# Load and preprocess drug names
def load_drug_names(file_path):
    """Read and normalize the reference drug names from a CSV file.

    The file must contain a ``drug_name`` column. Every value is converted to
    text, lower-cased and stripped. Missing cells and entries that are blank
    after stripping are dropped, so each returned name is a non-empty string
    that can be embedded.

    Args:
        file_path: path or file-like object accepted by ``pd.read_csv``.

    Returns:
        list[str]: the cleaned drug names, in file order.

    If the column is missing, an error is shown in the Streamlit app and the
    run is halted with ``st.stop()``.
    """
    df = pd.read_csv(file_path)
    if 'drug_name' in df.columns:
        # dropna + astype(str): without them the .str accessor yields NaN for
        # empty cells (which the tokenizer rejects) and raises outright on an
        # all-numeric column.
        names = df['drug_name'].dropna().astype(str).str.lower().str.strip()
        return names[names != ''].tolist()
    st.error("Column 'drug_name' not found in the CSV file.")
    st.stop()
# Get embeddings
def get_embeddings(texts, batch_size=64):
    """Encode ``texts`` into fixed-size DistilBERT vectors.

    Each text is tokenized and run through ``model`` without gradient
    tracking. Its final hidden states are then averaged over the real tokens
    only: padding positions are excluded using the tokenizer's attention
    mask. Without the mask, a short text padded to the length of the longest
    text in its batch would get a different vector than when encoded alone.
    The index vectors (built in large batches) and single-query vectors would
    then not be comparable.

    NOTE: the masked pooling changes the stored vectors compared with the
    earlier unmasked pooling. A FAISS index file built before this change
    should be deleted so it is rebuilt.

    Args:
        texts: sequence of strings to encode.
        batch_size: number of texts per forward pass; bounds peak memory
            when encoding a long list.

    Returns:
        numpy.ndarray of dtype float32, shape
        ``(len(texts), model.config.hidden_size)``.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    texts = list(texts)
    if not texts:
        # Avoid calling the tokenizer/model with an empty batch; return a
        # 0-row matrix of the correct width instead.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            return_tensors='pt',
            padding=True,
            truncation=True,
        )
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        # Broadcastable (batch, seq_len, 1) mask: 1 for real tokens, 0 for padding.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1)  # defensive: never divide by zero
        chunks.append((summed / counts).numpy())
    return np.concatenate(chunks, axis=0).astype(np.float32, copy=False)
# Create FAISS index
def create_faiss_index(embeddings):
    """Build an exact L2-distance FAISS index holding every row of ``embeddings``.

    Args:
        embeddings: 2-D array with one vector per row; its column count
            becomes the index dimension.

    Returns:
        A ``faiss.IndexFlatL2`` populated with all rows.
    """
    flat_index = faiss.IndexFlatL2(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index
# Load FAISS index
def load_faiss_index(index_file):
    """Read the serialized FAISS index stored at ``index_file`` and return it."""
    loaded_index = faiss.read_index(index_file)
    return loaded_index
# Check if FAISS index is empty
def is_faiss_index_empty(index_file):
    """Report whether ``index_file`` lacks a usable, non-empty FAISS index.

    Returns:
        True if the file cannot be read as a FAISS index (for example it is
        missing or corrupt) or if it holds zero vectors; False otherwise.
    """
    try:
        index = faiss.read_index(index_file)
    except Exception:
        # Deliberate best effort: any read failure means "rebuild the index".
        # Catching Exception instead of using a bare except lets
        # KeyboardInterrupt and SystemExit propagate.
        return True
    return index.ntotal == 0
# Search FAISS index
def search_index(index, embedding, k=1):
    """Query ``index`` for the ``k`` nearest neighbours of each row of ``embedding``.

    Args:
        index: FAISS index to query.
        embedding: 2-D array of query vectors, one per row.
        k: number of neighbours returned per query row.

    Returns:
        The ``(distances, indices)`` pair produced by ``index.search``.
    """
    neighbour_dists, neighbour_ids = index.search(embedding, k)
    return neighbour_dists, neighbour_ids
# Load drug names
drug_names = load_drug_names('drug_names.csv')
if not drug_names:
    # An empty reference list would give an empty index, and looking up any
    # search result in an empty drug_names list raises IndexError.
    st.error("No drug names found in 'drug_names.csv'.")
    st.stop()
# Reuse the cached FAISS index only when it holds exactly one vector per drug
# name. A missing, empty, or stale index (e.g. after drug_names.csv gained or
# lost rows) is rebuilt, so search positions map back onto drug_names.
# Only the vector count is compared: an edit that keeps the same number of
# rows still requires deleting INDEX_FILE by hand.
INDEX_FILE = 'faiss_index.index'
index = None
if not is_faiss_index_empty(INDEX_FILE):
    index = load_faiss_index(INDEX_FILE)
    if index.ntotal != len(drug_names):
        index = None
if index is None:
    embeddings = get_embeddings(drug_names)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, INDEX_FILE)
# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")
# Single input prediction
single_drug_name = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict Single Drug Name"):
    # Normalize the same way as the reference names, then test emptiness, so
    # whitespace-only input prompts the user instead of matching an
    # arbitrary drug.
    query = (single_drug_name or '').lower().strip()
    if query:
        single_embedding = get_embeddings([query])
        distances, indices = search_index(index, single_embedding)
        closest_drug_name = drug_names[indices[0][0]]
        st.write(f"Predicted Drug Name: {closest_drug_name}")
    else:
        st.write("Please enter a drug name to predict.")
# Batch prediction
st.header("Batch Prediction")
uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")
if uploaded_pred_file is not None:
    st.write("Uploaded prediction file preview:")
    pred_df = pd.read_csv(uploaded_pred_file)
    st.write(pred_df.head())
    # Prefer 'predicted_drug_name' and fall back to 'drug_name'.
    if 'predicted_drug_name' in pred_df.columns:
        text_column = 'predicted_drug_name'
    elif 'drug_name' in pred_df.columns:
        text_column = 'drug_name'
    else:
        st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
        st.stop()
    # fillna('') + astype(str) keep exactly one entry per uploaded row, so
    # serial numbers line up with the file. They also stop the .str accessor
    # from producing NaN for empty cells or failing on an all-numeric column.
    pred_texts = (
        pred_df[text_column].fillna('').astype(str).str.lower().str.strip().tolist()
    )
    # Blank cells get an empty match. Every other row is embedded and
    # searched in one batched FAISS query instead of one query per row.
    closest_names = [''] * len(pred_texts)
    query_rows = [row for row, text in enumerate(pred_texts) if text]
    if query_rows:
        pred_embeddings = get_embeddings([pred_texts[row] for row in query_rows])
        distances, indices = search_index(index, pred_embeddings)
        for row, neighbour in zip(query_rows, indices[:, 0]):
            closest_names[row] = drug_names[neighbour]
    predictions = [
        (serial_no, text, closest)
        for serial_no, (text, closest) in enumerate(zip(pred_texts, closest_names), start=1)
    ]
    results_df = pd.DataFrame(predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name'])
    results_df.to_csv('predictions_with_matches.csv', index=False)
    st.write("Batch prediction completed. You can download the results below.")
    st.download_button(
        label="Download Predictions",
        data=results_df.to_csv(index=False).encode('utf-8'),
        file_name='predictions_with_matches.csv',
        mime='text/csv',
    )
|