"""Streamlit app that maps partial or misspelled drug names to known drugs.

Drug names from ``drug_names.csv`` are embedded with DistilBERT (mean-pooled
over real, non-padding tokens), stored in a FAISS L2 index persisted to
``faiss_index.index``, and each query is matched to its nearest stored name.
"""
import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import DistilBertModel, DistilBertTokenizer

MODEL_NAME = 'distilbert-base-uncased'
DRUG_NAMES_FILE = 'drug_names.csv'
INDEX_FILE = 'faiss_index.index'
RESULTS_FILE = 'predictions_with_matches.csv'


# Streamlit re-executes this whole script on every widget interaction, so the
# model is cached per server process instead of being reloaded on each rerun.
@st.cache_resource
def _load_model():
    """Load the DistilBERT tokenizer and encoder once per server process."""
    return (
        DistilBertTokenizer.from_pretrained(MODEL_NAME),
        DistilBertModel.from_pretrained(MODEL_NAME),
    )


tokenizer, model = _load_model()


def load_drug_names(file_path):
    """Return the normalised, non-empty drug names from ``file_path``.

    Names are lower-cased and stripped. Missing (NaN) and blank entries are
    dropped so that every element can be tokenized.

    Stops the Streamlit run with an error if the ``drug_name`` column is
    absent or yields no usable names.
    """
    df = pd.read_csv(file_path)
    if 'drug_name' not in df.columns:
        st.error("Column 'drug_name' not found in the CSV file.")
        st.stop()
    normalised = df['drug_name'].dropna().astype(str).str.lower().str.strip()
    names = [name for name in normalised if name]
    if not names:
        st.error("No drug names found in the 'drug_name' column.")
        st.stop()
    return names


def get_embeddings(texts, batch_size=256):
    """Embed ``texts`` as a float32 array of shape ``(len(texts), hidden_size)``.

    Each row is the mean of DistilBERT's last hidden states over that text's
    real tokens. Padding positions are excluded via the attention mask. As a
    result, a text's embedding does not depend (beyond floating-point noise)
    on which other texts share its batch. This matters because the index is
    built in batches while single queries are embedded alone.

    Args:
        texts: strings to embed.
        batch_size: texts per forward pass; bounds peak memory when embedding
            a large drug list.

    Returns:
        A C-contiguous float32 NumPy array, as FAISS requires.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    chunks = []
    for start in range(0, len(texts), batch_size):
        inputs = tokenizer(
            texts[start:start + batch_size],
            return_tensors='pt',
            padding=True,
            truncation=True,
        )
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        # Zero out padding positions and divide by the real-token count.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        chunks.append(pooled.numpy())
    return np.ascontiguousarray(np.vstack(chunks), dtype=np.float32)


def create_faiss_index(embeddings):
    """Build an exact (brute-force) L2 FAISS index holding ``embeddings``."""
    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


def load_faiss_index(index_file):
    """Read a FAISS index from ``index_file``."""
    return faiss.read_index(index_file)


def _read_index_or_none(index_file):
    """Read a FAISS index, or return None if it is missing or unreadable."""
    if not os.path.exists(index_file):
        return None
    try:
        return faiss.read_index(index_file)
    except RuntimeError:
        # faiss reports read/format failures as RuntimeError.
        return None


def is_faiss_index_empty(index_file):
    """Return True if ``index_file`` is missing, unreadable, or holds no vectors."""
    index = _read_index_or_none(index_file)
    return index is None or index.ntotal == 0


def search_index(index, embedding, k=1):
    """Return ``(distances, indices)`` of the ``k`` nearest stored vectors.

    ``embedding`` is a 2-D array with one query per row. Both results have
    shape ``(n_queries, k)``.
    """
    return index.search(np.ascontiguousarray(embedding, dtype=np.float32), k)


def _index_matches(index, names):
    """Return True if ``index`` plausibly holds the current embeddings of ``names``.

    Checks the vector count and dimension, then re-embeds the first and last
    name and compares them with the stored vectors. This catches an edited
    CSV or an index built with a different pooling scheme.
    """
    if index.ntotal != len(names) or index.d != model.config.hidden_size:
        return False
    probe_ids = sorted({0, len(names) - 1})
    expected = get_embeddings([names[i] for i in probe_ids])
    try:
        stored = np.vstack([index.reconstruct(i) for i in probe_ids])
    except RuntimeError:
        # Index type does not support reconstruct(); cannot validate it.
        return False
    return np.allclose(stored, expected, atol=1e-3)


@st.cache_resource
def _get_index(index_file, names):
    """Load the on-disk index if it matches ``names``; otherwise rebuild and save it.

    ``names`` is a tuple so Streamlit can hash it as the cache key. A changed
    drug list therefore triggers revalidation instead of reusing a stale index.
    """
    index = _read_index_or_none(index_file)
    if index is None or not _index_matches(index, names):
        index = create_faiss_index(get_embeddings(list(names)))
        faiss.write_index(index, index_file)
    return index


# Load drug names and the (validated or rebuilt) FAISS index.
drug_names = load_drug_names(DRUG_NAMES_FILE)
index = _get_index(INDEX_FILE, tuple(drug_names))

# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")

# Single input prediction
single_drug_name = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict Single Drug Name"):
    query = single_drug_name.lower().strip()
    # Check after stripping so whitespace-only input is treated as empty.
    if query:
        _, indices = search_index(index, get_embeddings([query]))
        closest_drug_name = drug_names[indices[0][0]]
        st.write(f"Predicted Drug Name: {closest_drug_name}")
    else:
        st.write("Please enter a drug name to predict.")

# Batch prediction
st.header("Batch Prediction")
uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")
if uploaded_pred_file is not None:
    st.write("Uploaded prediction file preview:")
    try:
        pred_df = pd.read_csv(uploaded_pred_file)
    except pd.errors.EmptyDataError:
        st.error("The uploaded CSV file is empty.")
        st.stop()
    st.write(pred_df.head())

    if 'predicted_drug_name' in pred_df.columns:
        text_column = 'predicted_drug_name'
    elif 'drug_name' in pred_df.columns:
        text_column = 'drug_name'
    else:
        st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
        st.stop()

    # Missing cells become '' so each row keeps its serial number and the
    # tokenizer never receives a float NaN.
    pred_texts = (
        pred_df[text_column].fillna('').astype(str).str.lower().str.strip().tolist()
    )
    if not pred_texts:
        st.warning("The uploaded CSV file has no rows to predict.")
    else:
        # One vectorised search for all rows instead of a per-row loop.
        _, indices = search_index(index, get_embeddings(pred_texts))
        predictions = [
            (serial, text, drug_names[nearest] if text else '')
            for serial, (text, nearest) in enumerate(zip(pred_texts, indices[:, 0]), start=1)
        ]
        results_df = pd.DataFrame(
            predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name']
        )
        # NOTE: every session writes to this same server-side file.
        results_df.to_csv(RESULTS_FILE, index=False)
        st.write("Batch prediction completed. You can download the results below.")
        st.download_button(
            label="Download Predictions",
            data=results_df.to_csv(index=False).encode('utf-8'),
            file_name='predictions_with_matches.csv',
            mime='text/csv',
        )