Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,64 +1,113 @@
[The previous 64-line version of app.py was removed; only stray fragments of the old code survive in this view. The new 113-line version follows.]
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset
from groq import Groq
import os

# Initialize the Groq client; read the API key from a GROQ_API_KEY
# environment variable (e.g. a Space secret) rather than committing it
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Paths to the source books. Note: /content/... is a Google Colab
# convention; on a Hugging Face Space these files must exist in the repo
book_paths = {
    "DSM": "/content/Diagnostic and statistical manual of mental disorders _ DSM-5 ( PDFDrive.com ).pdf",
    "Personality": "/content/b6c3v8_Theories_of_Personality_10.pdf",
    "SearchForMeaning": "/content/Mans-Search-For-Meaning.pdf"
}

# Load and preprocess the data from the books
def load_data(paths):
    data = []
    for title, path in paths.items():
        # NB: this opens PDFs as plain text, so errors='ignore' discards
        # most of the binary content; see the extraction sketch below
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            text = file.read()
            paragraphs = text.split("\n\n")  # Split by paragraphs (adjust as needed)
            for paragraph in paragraphs:
                if paragraph.strip():  # Skip empty paragraphs
                    data.append({"text": paragraph.strip()})
    return Dataset.from_list(data)
```
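The loader above reads the PDFs as if they were plain text, so very little real prose survives the `errors='ignore'` pass. A minimal extraction sketch, assuming the `pypdf` package (not used in the original) is available; `read_pdf_text` is a hypothetical helper to wire into `load_data` in place of `open()`/`file.read()`:

```python
# Sketch: pull actual text out of a PDF before paragraph-splitting.
# Assumes pypdf is installed (pip install pypdf).
from pypdf import PdfReader

def read_pdf_text(path):
    reader = PdfReader(path)
    # extract_text() may return None for image-only pages, hence the "or ''"
    return "\n\n".join(page.extract_text() or "" for page in reader.pages)
```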
```python
# Load and preprocess the dataset for fine-tuning
dataset = load_data(book_paths)

# Load pretrained model and tokenizer from Hugging Face
model_name = "gpt2"  # Replace with a larger model if needed and feasible
tokenizer = AutoTokenizer.from_pretrained(model_name)

# GPT-2 has no padding token; reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

# Tokenize the text and build labels for causal language modeling
def tokenize_function(examples):
    # Pad to a fixed length so examples can be stacked into batches
    encodings = tokenizer(examples["text"], truncation=True,
                          padding="max_length", max_length=512)

    # Labels are a copy of input_ids (the model shifts them internally);
    # with batched=True, input_ids is a list of sequences, so mask per token
    encodings["labels"] = [
        [tok if tok != tokenizer.pad_token_id else -100 for tok in seq]
        for seq in encodings["input_ids"]
    ]
    return encodings

tokenized_dataset = dataset.map(tokenize_function, batched=True)
```
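Positions holding `pad_token_id` are set to -100 so the cross-entropy loss ignores padding (since `pad_token` is aliased to `eos_token`, genuine end-of-text tokens get masked too). A common alternative is to let a collator build the labels; a sketch using transformers' `DataCollatorForLanguageModeling` with `mlm=False`:

```python
# Sketch: let a collator create causal-LM labels instead of doing it by hand.
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Pass data_collator=data_collator to Trainer(...) below and drop the manual
# "labels" logic from tokenize_function.
```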
```python
# Split the dataset into train and eval sets for validation
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",          # Output directory for model and logs
    eval_strategy="epoch",           # Evaluate each epoch (supersedes evaluation_strategy)
    learning_rate=2e-5,              # Learning rate
    per_device_train_batch_size=8,   # Batch size for training
    per_device_eval_batch_size=8,    # Batch size for evaluation
    num_train_epochs=3,              # Number of training epochs
    weight_decay=0.01,               # Weight decay for regularization
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,   # Evaluation dataset
    tokenizer=tokenizer,         # Lets the Trainer save the tokenizer alongside checkpoints
)

# Fine-tune the model
trainer.train()

# Save the model after fine-tuning
model.save_pretrained("./fine_tuned_model")
tokenizer.save_pretrained("./fine_tuned_model")
```
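As written, the fine-tuned model is saved but never queried: `get_response` below answers every prompt through Groq's hosted `llama3-8b-8192` instead. If the intent is to serve the model that was just trained, a minimal generation sketch (paths as saved above; `local_response` is a hypothetical helper):

```python
# Sketch: answer with the locally fine-tuned GPT-2 instead of the Groq API.
from transformers import AutoTokenizer, AutoModelForCausalLM

ft_tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_model")
ft_model = AutoModelForCausalLM.from_pretrained("./fine_tuned_model")

def local_response(prompt, max_new_tokens=200):
    inputs = ft_tokenizer(prompt, return_tensors="pt")
    outputs = ft_model.generate(**inputs, max_new_tokens=max_new_tokens,
                                pad_token_id=ft_tokenizer.eos_token_id)
    return ft_tokenizer.decode(outputs[0], skip_special_tokens=True)
```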
```python
# Step 4: Define the response function with an emergency keyword check
def get_response(user_input):
    # Check for emergency/distress keywords
    distress_keywords = ["hopeless", "emergency", "help", "crisis", "urgent"]
    is_distress = any(word in user_input.lower() for word in distress_keywords)

    # Use the Groq API to generate a response
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": user_input}],
        model="llama3-8b-8192",  # Or replace with another hosted model
    )
    response = chat_completion.choices[0].message.content

    # Append an emergency notice if distress keywords are detected
    if is_distress:
        response += "\n\nThis seems serious. Please consider reaching out to an emergency contact immediately. In case of an emergency, call [emergency number]."

    return response

# Step 5: Set up the Gradio interface
import gradio as gr

def chatbot_interface(input_text):
    return get_response(input_text)

# Launch the Gradio app
gr.Interface(fn=chatbot_interface, inputs="text", outputs="text",
             title="Virtual Psychiatrist Chatbot").launch()
```
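One note on the "Runtime error" banner: everything in this file runs at import time, so the Space attempts the full fine-tuning pass (and the Colab-style `/content/...` file reads) before the Gradio UI ever launches, and any failure there (missing PDFs, missing GROQ_API_KEY) is a plausible source of the error shown above. A sketch of one way to keep the UI responsive, gating training behind a hypothetical `RUN_FINETUNE` environment flag:

```python
# Sketch: run fine-tuning only when explicitly requested, so the Space can
# start even if the PDFs are missing. RUN_FINETUNE is a hypothetical flag.
import os

def maybe_finetune():
    if os.environ.get("RUN_FINETUNE") != "1":
        return
    dataset = load_data(book_paths)
    tokenized = dataset.map(tokenize_function, batched=True)
    split = tokenized.train_test_split(test_size=0.1)
    Trainer(model=model, args=training_args,
            train_dataset=split["train"],
            eval_dataset=split["test"],
            tokenizer=tokenizer).train()
    model.save_pretrained("./fine_tuned_model")
    tokenizer.save_pretrained("./fine_tuned_model")
```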