lmmithun committed on
Commit
86a2cca
·
verified ·
1 Parent(s): 85022cc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import torch
3
+ import faiss
4
+ from transformers import DistilBertTokenizer, DistilBertModel
5
+ import streamlit as st
6
+ import numpy as np
7
+
8
# Initialize tokenizer and model from the same pretrained checkpoint so the
# vocabulary and the weights are guaranteed to match.
_CHECKPOINT = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(_CHECKPOINT)
model = DistilBertModel.from_pretrained(_CHECKPOINT)
11
+
12
+ # Load and preprocess drug names
13
def load_drug_names(file_path):
    """Load the reference drug names from a CSV file.

    The CSV must contain a ``drug_name`` column.  Names are lower-cased and
    stripped of surrounding whitespace.  Missing cells and names that are
    empty after stripping are dropped, and non-string cells are converted to
    text, so that every returned entry is a non-empty ``str`` that can be
    handed to the tokenizer.

    Args:
        file_path: Path (or file-like object) accepted by ``pandas.read_csv``.

    Returns:
        A list of normalized drug names, in file order.

    If the ``drug_name`` column is absent an error is shown through Streamlit
    and the script run is halted with ``st.stop()``.
    """
    df = pd.read_csv(file_path)
    if 'drug_name' in df.columns:
        # dropna first: a blank cell is read as NaN, and NaN survives the
        # ``.str`` accessors as a float that the tokenizer cannot handle.
        names = df['drug_name'].dropna().astype(str).str.lower().str.strip()
        return [name for name in names if name]
    else:
        st.error("Column 'drug_name' not found in the CSV file.")
        st.stop()
20
+
21
+ # Get embeddings
22
def get_embeddings(texts, batch_size=64):
    """Embed texts with the module-level DistilBERT tokenizer and model.

    Each text is represented by the mean of its token vectors from the
    model's last hidden state.  Padding positions are excluded from the mean
    via the attention mask, so the vector for a given text does not depend on
    which other texts it was batched with.

    NOTE: an index persisted with the previous (unmasked) pooling holds
    vectors computed differently and should be rebuilt.

    Args:
        texts: A string or a sequence of strings.
        batch_size: Number of texts sent through the model at once; bounds
            peak memory when embedding a large list.

    Returns:
        A C-contiguous ``float32`` NumPy array with one row per text.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    pooled_batches = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        hidden = outputs.last_hidden_state
        # Zero out padding positions, then divide by the real token count.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled_batches.append(((hidden * mask).sum(dim=1) / token_counts).numpy())

    return np.ascontiguousarray(np.concatenate(pooled_batches, axis=0), dtype=np.float32)
27
+
28
+ # Create FAISS index
29
def create_faiss_index(embeddings):
    """Build an exact L2 FAISS index over the given embedding matrix.

    Args:
        embeddings: 2-D array-like with one row per vector.

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``embeddings``.
    """
    # FAISS only accepts C-contiguous float32 matrices; coerce up front so
    # float64 or non-contiguous input does not fail inside ``add``.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    dimension = vectors.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(vectors)
    return index
34
+
35
+ # Load FAISS index
36
def load_faiss_index(index_file):
    """Read the FAISS index stored at ``index_file`` and return it."""
    loaded_index = faiss.read_index(index_file)
    return loaded_index
38
+
39
+ # Check if FAISS index is empty
40
def is_faiss_index_empty(index_file):
    """Report whether ``index_file`` lacks a usable, non-empty FAISS index.

    Args:
        index_file: Path of the persisted index.

    Returns:
        ``True`` when the file does not exist, cannot be read as a FAISS
        index, or holds zero vectors; ``False`` otherwise.
    """
    import os

    # A missing file is the normal first-run case; handle it without relying
    # on an exception from FAISS.
    if not os.path.isfile(index_file):
        return True
    try:
        index = faiss.read_index(index_file)
    except Exception:
        # Best effort: an unreadable or corrupt file is treated as "needs
        # rebuilding".  Unlike a bare ``except``, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return True
    return index.ntotal == 0
46
+
47
+ # Search FAISS index
48
def search_index(index, embedding, k=1):
    """Return the ``k`` nearest neighbours of each query vector.

    Args:
        index: Object exposing ``search(queries, k)`` (e.g. a FAISS index).
        embedding: A single query vector (1-D) or a matrix of query vectors
            (2-D, one per row).  It is converted to a C-contiguous float32
            matrix before searching, so a bare vector or a float64 array is
            accepted as well.
        k: Number of neighbours to retrieve per query.

    Returns:
        The ``(distances, indices)`` pair produced by ``index.search``.
    """
    queries = np.ascontiguousarray(np.atleast_2d(embedding), dtype=np.float32)
    distances, indices = index.search(queries, k)
    return distances, indices
51
+
52
# Load drug names
drug_names = load_drug_names('drug_names.csv')
if not drug_names:
    # Nothing to index or match against; stop instead of failing later.
    st.error("No drug names found in 'drug_names.csv'.")
    st.stop()

# Reuse the persisted FAISS index only when it is readable, non-empty and has
# exactly one vector per drug name. Search results are positions into
# `drug_names`, so an index built from an older CSV would yield wrong or
# out-of-range lookups.
_INDEX_PATH = 'faiss_index.index'
index = None
if not is_faiss_index_empty(_INDEX_PATH):
    index = load_faiss_index(_INDEX_PATH)
    if index.ntotal != len(drug_names):
        index = None

if index is None:
    embeddings = get_embeddings(drug_names)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, _INDEX_PATH)
62
+
63
# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")

# Single input prediction
single_drug_name = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict Single Drug Name"):
    # Normalize before the emptiness check so whitespace-only input is
    # rejected instead of being embedded as an empty string.
    query = (single_drug_name or "").lower().strip()
    if query:
        single_embedding = get_embeddings([query])
        distances, indices = search_index(index, single_embedding)
        best = int(indices[0][0])
        # FAISS reports -1 when it has no neighbour to return; without this
        # check a negative position would silently pick the last drug name.
        if 0 <= best < len(drug_names):
            closest_drug_name = drug_names[best]
            st.write(f"Predicted Drug Name: {closest_drug_name}")
        else:
            st.write("No matching drug name found.")
    else:
        st.write("Please enter a drug name to predict.")
76
+
77
# Batch prediction
st.header("Batch Prediction")
uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")
if uploaded_pred_file is not None:
    st.write("Uploaded prediction file preview:")
    pred_df = pd.read_csv(uploaded_pred_file)
    st.write(pred_df.head())

    # Prefer the explicit prediction column, fall back to 'drug_name'.
    if 'predicted_drug_name' in pred_df.columns:
        source_column = 'predicted_drug_name'
    elif 'drug_name' in pred_df.columns:
        source_column = 'drug_name'
    else:
        st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
        st.stop()

    # Blank cells are read as NaN and numeric cells as numbers; drop the
    # former and stringify the latter so every entry can be tokenized.
    pred_texts = (
        pred_df[source_column].dropna().astype(str).str.lower().str.strip().tolist()
    )
    if not pred_texts:
        st.error(f"Column '{source_column}' does not contain any drug names.")
        st.stop()

    # One batched search instead of one search call per row.
    pred_embeddings = get_embeddings(pred_texts)
    distances, indices = search_index(index, pred_embeddings)

    predictions = []
    for serial, (pred_text, neighbours) in enumerate(zip(pred_texts, indices), start=1):
        best = int(neighbours[0])
        # FAISS reports -1 when it has no neighbour to return; leave the
        # match empty rather than indexing from the end of the list.
        closest_drug_name = drug_names[best] if 0 <= best < len(drug_names) else ""
        predictions.append((serial, pred_text, closest_drug_name))

    results_df = pd.DataFrame(predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name'])
    results_df.to_csv('predictions_with_matches.csv', index=False)
    st.write("Batch prediction completed. You can download the results below.")
    st.download_button(
        label="Download Predictions",
        data=results_df.to_csv(index=False).encode('utf-8'),
        file_name='predictions_with_matches.csv',
        mime='text/csv',
    )