FAISS / app.py
import pandas as pd
import torch
import faiss
from transformers import DistilBertTokenizer, DistilBertModel
import streamlit as st
import numpy as np

# Initialize tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
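# Note: the base DistilBERT encoder is used purely for feature extraction here;
# get_embeddings() below mean-pools its token outputs into one 768-dimensional
# vector per input string, with no task-specific head or fine-tuning involved.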

# Load and preprocess drug names
def load_drug_names(file_path):
    df = pd.read_csv(file_path)
    if 'drug_name' in df.columns:
        return df['drug_name'].str.lower().str.strip().tolist()
    else:
        st.error("Column 'drug_name' not found in the CSV file.")
        st.stop()

# Get embeddings
def get_embeddings(texts):
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy()
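
# Illustrative sanity check (not executed by the app): distilbert-base-uncased has a
# hidden size of 768, so embedding two strings yields a float32 array of shape (2, 768):
#   get_embeddings(["aspirin", "ibuprofen"]).shape  # -> (2, 768)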

# Create FAISS index
def create_faiss_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index
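
# IndexFlatL2 performs exact (brute-force) L2 search, which is adequate for a list of
# drug names; FAISS expects a contiguous float32 array, which the mean-pooled
# DistilBERT output above already provides.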

# Load FAISS index
def load_faiss_index(index_file):
    return faiss.read_index(index_file)

# Check if the FAISS index is empty or cannot be read
def is_faiss_index_empty(index_file):
    try:
        index = faiss.read_index(index_file)
        return index.ntotal == 0
    except Exception:
        # Treat a missing or unreadable index file as "empty" so it gets rebuilt
        return True

# Search FAISS index
def search_index(index, embedding, k=1):
    distances, indices = index.search(embedding, k)
    return distances, indices
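
# search() expects a 2-D query array of shape (n_queries, dim) and returns distances
# and indices, each of shape (n_queries, k). Illustrative usage only:
#   distances, indices = search_index(index, get_embeddings(["asprin"]), k=3)
#   # indices[0] holds the positions of the 3 nearest entries in drug_names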

# Load drug names
drug_names = load_drug_names('drug_names.csv')

# Check if FAISS index needs to be created
if is_faiss_index_empty('faiss_index.index'):
    embeddings = get_embeddings(drug_names)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, 'faiss_index.index')
else:
    index = load_faiss_index('faiss_index.index')
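
# Note: the index is rebuilt from drug_names.csv only when faiss_index.index is
# missing or empty; otherwise the previously saved index is reused as-is, so the
# file has to be deleted manually if drug_names.csv changes.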

# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")

# Single input prediction
single_drug_name = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict Single Drug Name"):
    if single_drug_name:
        single_embedding = get_embeddings([single_drug_name.lower().strip()])
        distances, indices = search_index(index, single_embedding)
        closest_drug_name = drug_names[indices[0][0]]
        st.write(f"Predicted Drug Name: {closest_drug_name}")
    else:
        st.write("Please enter a drug name to predict.")

# Batch prediction
st.header("Batch Prediction")
uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")
if uploaded_pred_file is not None:
    st.write("Uploaded prediction file preview:")
    pred_df = pd.read_csv(uploaded_pred_file)
    st.write(pred_df.head())

    if 'predicted_drug_name' in pred_df.columns:
        pred_texts = pred_df['predicted_drug_name'].str.lower().str.strip().tolist()
    elif 'drug_name' in pred_df.columns:
        pred_texts = pred_df['drug_name'].str.lower().str.strip().tolist()
    else:
        st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
        st.stop()

    pred_embeddings = get_embeddings(pred_texts)
    predictions = []
    for i, (pred_text, pred_embedding) in enumerate(zip(pred_texts, pred_embeddings), start=1):
        distances, indices = search_index(index, np.expand_dims(pred_embedding, axis=0))
        closest_drug_name = drug_names[indices[0][0]]
        predictions.append((i, pred_text, closest_drug_name))

    results_df = pd.DataFrame(predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name'])
    results_df.to_csv('predictions_with_matches.csv', index=False)
    st.write("Batch prediction completed. You can download the results below.")
    st.download_button(
        label="Download Predictions",
        data=results_df.to_csv(index=False).encode('utf-8'),
        file_name='predictions_with_matches.csv',
        mime='text/csv',
    )