File size: 2,122 Bytes
885b434
 
b8774c1
9e25427
885b434
9e25427
cb34ab7
9e25427
5f17867
9e25427
b8774c1
7be036b
 
 
 
 
 
885b434
d48bb43
885b434
d48bb43
9e25427
885b434
d48bb43
 
 
9e25427
40c4a66
b9bec37
9e25427
885b434
 
 
 
9e25427
 
885b434
9e25427
 
 
885b434
091f6c9
943a7c3
1025e47
9e25427
885b434
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

# Identifier handed to from_pretrained() for both the tokenizer and the model.
MODEL_URL = "kingabzpro/Llama-3.1-8B-Instruct-Mental-Health-Classification"

tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
# Reuse the end-of-sequence token id as the padding token id.
# NOTE(review): presumably the checkpoint ships without its own pad token - confirm.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the causal-LM weights once at import time: float16 weights, placed on
# the CPU, with the low-CPU-memory loading path enabled.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_URL,
    device_map="cpu",
    torch_dtype=torch.float16,
    return_dict=True,
    low_cpu_mem_usage=True,
)

# Text-generation pipeline shared by all calls to prediction(); built on first use.
_generator = None


def _get_generator():
    """Return the shared text-generation pipeline, creating it on first use."""
    global _generator
    if _generator is None:
        # Built from the module-level tokenizer and model, so constructing it
        # once is enough; rebuilding it on every request is wasted work.
        _generator = pipeline(
            "text-generation",
            tokenizer=tokenizer,
            model=model,
            torch_dtype=torch.float16,
            device_map="cpu",
        )
    return _generator


def prediction(text):
    """Classify *text* and return the label text produced by the model.

    Args:
        text: Free-form user text to classify.

    Returns:
        The stripped continuation generated after the ``label:`` cue (at most
        two new tokens). The prompt asks for one of Normal, Depression,
        Anxiety or Bipolar, but the output is whatever the model emits.
        An empty string is returned for empty or whitespace-only input,
        without running the model.
    """
    # Nothing to classify: skip inference instead of labelling an empty prompt.
    if not text or not text.strip():
        return ""

    # NOTE(review): the continuation lines of this literal carry their four
    # leading spaces into the prompt, and .strip() removes the trailing space
    # after "label:". Kept byte-identical here; confirm it matches the prompt
    # format used when the model was fine-tuned.
    prompt = f"""Classify the text into Normal, Depression, Anxiety, Bipolar, and return the answer as the corresponding mental health disorder label.
    text: {text}
    label: """.strip()

    # return_full_text=False yields only the newly generated tokens, so the
    # label no longer has to be recovered by splitting prompt + continuation on
    # "label: " (which fails when the model emits no leading space, or picks
    # up a "label: " that appears inside the user's own text).
    outputs = _get_generator()(
        prompt,
        max_new_tokens=2,
        do_sample=True,
        temperature=0.1,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()


# Web UI: one multi-line text box in, one label component out, backed by prediction().
gradio_ui = gr.Interface(
    fn=prediction,
    inputs=gr.Textbox(lines=10, label="Write the text here"),
    outputs=gr.Label(num_top_classes=4, label="Mental Health Disorder Category"),
    title="Mental Health Disorder Classification",
    description=(
        "Input the text to generate a Mental Health Disorder.\n"
        f" For this classification, the {MODEL_URL} model was used."
    ),
    # Each inner list is one clickable example; its single item fills the text box.
    examples=[
        ["trouble sleeping, confused mind, restless heart. All out of tune"],
        ["In the quiet hours, even the shadows seem too heavy to bear."],
        ["Riding a tempest of emotions, where ecstatic highs crash into desolate lows without warning."],
    ],
    article=(
        "<p style='text-align: center'>"
        "Please read the tutorial to fine-tune the Llama 3.1 model on Mental Health Classification "
        "<a href='https://www.datacamp.com/tutorial/fine-tuning-llama-3-1' target='_blank'>"
        "https://www.datacamp.com/tutorial/fine-tuning-llama-3-1</a></p>"
    ),
    theme=gr.themes.Soft(),
)

# Start serving the interface as soon as the script is executed.
gradio_ui.launch()