import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

class Assistant:
    def __init__(self):
        model_name = "ruslanmv/Medical-Llama3-8B"

        # Optional: 4-bit NF4 quantization to fit the 8B model on smaller GPUs.
        # bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
        #                                 bnb_4bit_compute_dtype=torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # quantization_config=bnb_config,  # uncomment together with bnb_config above
            trust_remote_code=True,
            use_cache=False,
            device_map="auto",
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # The tokenizer ships without a pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token


    def respond(self, message, history):
        # gr.ChatInterface calls fn(message, history); history is accepted but
        # unused, since the model is prompted from scratch on every turn.
        sys_message = (
            "You are an AI Medical Assistant trained on a vast dataset of health "
            "information. Please be thorough and provide an informative answer. "
            "If you don't know the answer to a specific medical inquiry, advise "
            "seeking professional help."
        )
        messages = [
            {"role": "system", "content": sys_message},
            {"role": "user", "content": message},
        ]

        # Render the conversation with the model's chat template, leaving the
        # assistant turn open so generation continues from there.
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=100, use_cache=True)

        # Decode with special tokens intact so the assistant marker can be used
        # to strip the prompt, then return only the generated answer.
        response_text = self.tokenizer.batch_decode(outputs)[0].strip()
        answer = response_text.split("<|im_start|>assistant")[-1].strip()
        return answer


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
assistant = Assistant()

demo = gr.ChatInterface(
    assistant.respond
)
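
# A hedged customization sketch, not part of the running app: title,
# description, and examples are standard gr.ChatInterface parameters, and the
# example prompt is illustrative only.
#
# demo = gr.ChatInterface(
#     assistant.respond,
#     title="Medical-Llama3-8B Assistant",
#     description="Ask general health questions; not a substitute for a doctor.",
#     examples=["What are the common symptoms of dehydration?"],
# )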


if __name__ == "__main__":
    demo.launch()
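
# Headless smoke test without the UI (the prompt is illustrative; the empty
# list stands in for the chat history that gr.ChatInterface would pass):
#     print(assistant.respond("What helps with seasonal allergies?", []))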