import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import DistilBertModel, DistilBertTokenizer


@st.cache_resource
def _load_model(model_name='distilbert-base-uncased'):
    """Load the DistilBERT tokenizer and encoder, cached across script reruns.

    Streamlit re-executes this whole script on every widget interaction; without
    the cache the model weights would be reloaded from disk on each click.

    Args:
        model_name: Hugging Face model identifier to load.

    Returns:
        (tokenizer, model) tuple.
    """
    loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    loaded_model = DistilBertModel.from_pretrained(model_name)
    # from_pretrained already returns the model in eval mode; made explicit
    # here because this app only ever runs inference.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model()


def load_drug_names(file_path):
    """Read the reference drug list from a CSV file.

    Args:
        file_path: path to a CSV file that must contain a ``drug_name`` column.

    Returns:
        List of drug names, lower-cased and stripped, in file row order.
        Empty cells are skipped. The position of each name in this list is the
        id used by the FAISS index built from it.

    If the column is missing or contains no names, an error is shown in the
    Streamlit UI and the script run is stopped via ``st.stop()``.
    """
    df = pd.read_csv(file_path)

    if 'drug_name' not in df.columns:
        st.error("Column 'drug_name' not found in the CSV file.")
        st.stop()

    # dropna + astype(str): `.str` turns missing or non-string cells into NaN,
    # which the tokenizer cannot embed.
    names = df['drug_name'].dropna().astype(str).str.lower().str.strip().tolist()

    if not names:
        st.error("No drug names found in the 'drug_name' column of the CSV file.")
        st.stop()

    return names


def _masked_mean_pool(last_hidden_state, attention_mask):
    """Average token vectors over real (non-padding) tokens only.

    Args:
        last_hidden_state: float tensor of shape (batch, seq_len, hidden).
        attention_mask: tensor of shape (batch, seq_len), 1 for real tokens
            and 0 for padding, as produced by the tokenizer.

    Returns:
        Tensor of shape (batch, hidden). A row whose mask is all zeros yields
        a zero vector instead of NaN.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / token_counts


def get_embeddings(texts, batch_size=64):
    """Embed each string as the mean of its DistilBERT token vectors.

    Uses the module-level ``tokenizer`` and ``model``. Padding positions are
    excluded from the mean via the attention mask. Otherwise a string's vector
    would depend on the longest string in its batch, and vectors built in bulk
    for the index would not match the same string embedded alone as a query.
    NOTE: an index file saved before this pooling change (faiss_index.index)
    should be deleted so it is rebuilt with consistent vectors.

    Args:
        texts: list of strings to embed.
        batch_size: strings per forward pass. It bounds peak memory for long
            lists and does not change the results (up to float rounding).

    Returns:
        float32 numpy array of shape (len(texts), hidden_size); row i
        corresponds to texts[i].
    """
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        pooled = _masked_mean_pool(outputs.last_hidden_state, inputs['attention_mask'])
        chunks.append(pooled.cpu().numpy())

    return np.concatenate(chunks, axis=0)


def create_faiss_index(embeddings):
    """Build a flat (exhaustive) L2 FAISS index holding every row of *embeddings*.

    Args:
        embeddings: 2-D array of shape (n_vectors, dim); the index dimension
            is taken from the second axis.

    Returns:
        A ``faiss.IndexFlatL2`` whose stored ids 0..n-1 follow the row order
        of *embeddings*.
    """
    flat_l2 = faiss.IndexFlatL2(embeddings.shape[1])
    flat_l2.add(embeddings)
    return flat_l2


def load_faiss_index(index_file):
    """Read back a FAISS index that was saved with ``faiss.write_index``.

    Args:
        index_file: path of the serialized index.

    Returns:
        The deserialized FAISS index object.
    """
    restored_index = faiss.read_index(index_file)
    return restored_index


def is_faiss_index_empty(index_file):
    """Report whether a usable, non-empty FAISS index exists at *index_file*.

    Args:
        index_file: path of the serialized index.

    Returns:
        True if the file is missing, cannot be read as a FAISS index, or holds
        zero vectors; False otherwise. True means "rebuild the index".
    """
    if not os.path.exists(index_file):
        return True

    try:
        index = faiss.read_index(index_file)
    except RuntimeError:
        # faiss reports read failures (corrupt/truncated file, unreadable
        # path) as RuntimeError. Treat those like a missing index so it gets
        # rebuilt, without hiding unrelated errors as the old bare except did.
        return True

    return index.ntotal == 0


def search_index(index, embedding, k=1):
    """Look up the *k* closest stored vectors for each query row in *embedding*.

    Args:
        index: a FAISS index (anything exposing ``search(queries, k)``).
        embedding: 2-D array of query vectors, one per row.
        k: number of neighbours to return per query.

    Returns:
        (distances, indices) exactly as produced by ``index.search``.
    """
    found = index.search(embedding, k)
    distance_rows, id_rows = found
    return (distance_rows, id_rows)


drug_names = load_drug_names('drug_names.csv')

FAISS_INDEX_PATH = 'faiss_index.index'

# Reuse the saved index only if it still lines up with drug_names: search
# results are positions into drug_names, so a stale index returns wrong names
# or out-of-range positions. The row-count check catches added or removed
# rows; it cannot detect edits that keep the same count.
index = None
if not is_faiss_index_empty(FAISS_INDEX_PATH):
    index = load_faiss_index(FAISS_INDEX_PATH)
    if index.ntotal != len(drug_names):
        index = None

if index is None:
    embeddings = get_embeddings(drug_names)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, FAISS_INDEX_PATH)


st.title("Doctor's Handwritten Prescription Prediction")

single_drug_name = st.text_input("Enter the partial or misspelled drug name:")

if st.button("Predict Single Drug Name"):
    # Normalise the same way load_drug_names does (lower + strip), and test the
    # normalised value so whitespace-only input counts as empty instead of
    # being embedded as a blank string.
    query = single_drug_name.lower().strip()
    if query:
        single_embedding = get_embeddings([query])
        distances, indices = search_index(index, single_embedding)
        closest_drug_name = drug_names[indices[0][0]]
        st.write(f"Predicted Drug Name: {closest_drug_name}")
    else:
        st.write("Please enter a drug name to predict.")


st.header("Batch Prediction")

uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")

if uploaded_pred_file is not None:
    st.write("Uploaded prediction file preview:")
    pred_df = pd.read_csv(uploaded_pred_file)
    st.write(pred_df.head())

    # Prefer the explicit prediction column; fall back to a plain drug_name column.
    if 'predicted_drug_name' in pred_df.columns:
        source_column = 'predicted_drug_name'
    elif 'drug_name' in pred_df.columns:
        source_column = 'drug_name'
    else:
        st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
        st.stop()

    # dropna + astype(str): `.str` turns missing or non-string cells into NaN,
    # which the tokenizer cannot embed.
    pred_texts = (
        pred_df[source_column].dropna().astype(str).str.lower().str.strip().tolist()
    )
    if not pred_texts:
        st.error("The uploaded CSV file contains no drug names to match.")
        st.stop()

    pred_embeddings = get_embeddings(pred_texts)

    # One batched search for every row instead of one index query per row.
    _, nearest_ids = search_index(index, pred_embeddings)
    predictions = [
        (serial, pred_text, drug_names[ids[0]])
        for serial, (pred_text, ids) in enumerate(zip(pred_texts, nearest_ids), start=1)
    ]

    results_df = pd.DataFrame(predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name'])
    results_df.to_csv('predictions_with_matches.csv', index=False)

    st.write("Batch prediction completed. You can download the results below.")
    st.download_button(
        label="Download Predictions",
        data=results_df.to_csv(index=False).encode('utf-8'),
        file_name='predictions_with_matches.csv',
        mime='text/csv',
    )