Zeeshan42 committed on
Commit
4936193
·
verified ·
1 Parent(s): 5f12af9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -61
app.py CHANGED
@@ -1,64 +1,113 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- if __name__ == "__main__":
64
- demo.launch()
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
2
+ from datasets import Dataset
3
+ from groq import Groq
4
+ import os
5
+
6
+ # Initialize Groq client with your API key
7
# Read the Groq API key from the GROQ_API_KEY environment variable so the
# secret is never committed to source control. A missing variable raises
# KeyError at startup instead of failing later on the first request.
client = Groq(api_key=os.environ["GROQ_API_KEY"])
8
+
9
+ # Paths to your books
10
+ book_paths = {
11
+ "DSM": "/content/Diagnostic and statistical manual of mental disorders _ DSM-5 ( PDFDrive.com ).pdf",
12
+ "Personality": "/content/b6c3v8_Theories_of_Personality_10.pdf",
13
+ "SearchForMeaning": "/content/Mans-Search-For-Meaning.pdf"
14
+ }
15
+
16
+ # Function to load and preprocess the data from books
17
def load_data(paths):
    """Read each book file and build a dataset of its non-empty paragraphs.

    Args:
        paths: Mapping of book title to file path. Only the paths are read;
            the titles are not used.

    Returns:
        A ``Dataset`` with one ``{"text": paragraph}`` row for every
        non-empty, whitespace-stripped paragraph, in file order. Paragraphs
        are delimited by a blank line (``"\\n\\n"``).
    """
    import warnings  # local import: only needed for the PDF sanity check

    data = []
    for path in paths.values():
        with open(path, "r", encoding="utf-8", errors="ignore") as file:
            text = file.read()

        # PDF files begin with a "%PDF" header. Decoding one as UTF-8 text
        # (with undecodable bytes dropped) does not recover the book's prose,
        # so flag it instead of silently training on noise.
        if text.startswith("%PDF"):
            warnings.warn(
                f"{path} looks like a PDF; reading it as plain text will not "
                "yield the book's prose. Extract the text to a .txt file first.",
                stacklevel=2,
            )

        stripped_paragraphs = (chunk.strip() for chunk in text.split("\n\n"))
        data.extend({"text": chunk} for chunk in stripped_paragraphs if chunk)
    return Dataset.from_list(data)
27
+
28
+ # Load and preprocess dataset for fine-tuning
29
+ dataset = load_data(book_paths)
30
+
31
+ # Load pretrained model and tokenizer from Hugging Face
32
+ model_name = "gpt2" # Replace with a larger model if needed and feasible
33
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
34
+
35
+ # Set the pad_token to be the same as eos_token (fix for missing padding token)
36
+ tokenizer.pad_token = tokenizer.eos_token
37
+
38
+ model = AutoModelForCausalLM.from_pretrained(model_name)
39
+
40
+ # Tokenize data and create labels (shifted input for causal language modeling)
41
def tokenize_function(examples):
    """Tokenize a batch of texts and attach causal-LM labels.

    Intended for ``dataset.map(..., batched=True)``, where ``examples["text"]``
    is a list of strings and the tokenizer returns one list of ids per text.

    Args:
        examples: Batch dict with a ``"text"`` key holding a list of strings.

    Returns:
        The tokenizer's encodings with an added ``"labels"`` entry: a copy of
        ``input_ids`` in which every padding position is replaced by -100.
    """
    # Pad to a fixed length: padding=True only pads to the longest example of
    # the current map batch, so rows from different batches would end up with
    # different lengths.
    encodings = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )

    # Labels are an unshifted copy of input_ids (per the Transformers docs the
    # causal-LM head shifts them internally). Padding is identified through the
    # attention mask rather than by token id, because pad_token is set to
    # eos_token in this file and an id comparison could not tell them apart.
    # -100 is the index the loss ignores.
    encodings["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings
52
+
53
+ tokenized_dataset = dataset.map(tokenize_function, batched=True)
54
+
55
+ # Split dataset into train and eval (explicit split for better validation)
56
+ train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
57
+ train_dataset = train_test_split["train"]
58
+ eval_dataset = train_test_split["test"]
59
+
60
+ # Define training arguments
61
+ training_args = TrainingArguments(
62
+ output_dir="./results", # Output directory for model and logs
63
+ eval_strategy="epoch", # Use eval_strategy instead of evaluation_strategy
64
+ learning_rate=2e-5, # Learning rate
65
+ per_device_train_batch_size=8, # Batch size for training
66
+ per_device_eval_batch_size=8, # Batch size for evaluation
67
+ num_train_epochs=3, # Number of training epochs
68
+ weight_decay=0.01, # Weight decay for regularization
69
+ )
70
+
71
+ # Initialize the Trainer
72
+ trainer = Trainer(
73
+ model=model,
74
+ args=training_args,
75
+ train_dataset=train_dataset,
76
+ eval_dataset=eval_dataset, # Pass eval dataset for evaluation
77
+ tokenizer=tokenizer, # Provide tokenizer for model inference
78
  )
79
 
80
+ # Fine-tune the model
81
+ trainer.train()
82
+
83
+ # Save the model after fine-tuning
84
+ model.save_pretrained("./fine_tuned_model")
85
+ tokenizer.save_pretrained("./fine_tuned_model")
86
+
87
+ # Step 4: Define response function with emergency keyword check
88
def get_response(user_input):
    """Generate a chatbot reply, appending a safety notice for distress input.

    Args:
        user_input: The raw text typed by the user.

    Returns:
        The model's reply as a string. When the input contains any distress
        keyword, an emergency notice is appended after a blank line.
    """
    # Deliberately a substring match on the lower-cased input, so variants
    # such as "helpless" or "urgently" also trigger the notice.
    distress_keywords = ("hopeless", "emergency", "help", "crisis", "urgent")
    lowered_input = user_input.lower()
    is_distress = any(word in lowered_input for word in distress_keywords)

    # Use the Groq API to generate the reply.
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": user_input}],
        model="llama3-8b-8192",  # Or replace with another model
    )
    # The SDK types message content as optional; normalise to a string so the
    # concatenation below cannot fail on None.
    response = chat_completion.choices[0].message.content or ""

    if is_distress:
        response += (
            "\n\nThis seems serious. Please consider reaching out to an "
            "emergency contact immediately. In case of an emergency, call "
            "your local emergency number."
        )

    return response
105
+
106
+ # Step 5: Set up Gradio Interface
107
+ import gradio as gr
108
+
109
+ def chatbot_interface(input_text):
110
+ return get_response(input_text)
111
 
112
+ # Launch the Gradio app
113
+ gr.Interface(fn=chatbot_interface, inputs="text", outputs="text", title="Virtual Psychiatrist Chatbot").launch()