Mohammed-Altaf committed on
Commit
f56f478
·
1 Parent(s): be9ad62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -34
app.py CHANGED
@@ -1,44 +1,34 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
 
 
 
 
 
 
 
 
4
 
5
- model_id = None
6
 
7
- if model_id:
8
-
9
- model = AutoModelForCausalLM.from_pretrained(
10
- model_id,
11
- torch_dtype=torch.bfloat16,
12
- trust_remote_code=True,
13
- device_map="auto",
14
- low_cpu_mem_usage=True,
 
 
 
 
15
  )
16
- tokenizer = AutoTokenizer.from_pretrained(model_id)
17
- else:
18
- model = 'We have not downloaded model yet'
19
- tokenizer = 'we have not downloaded the tokenizer yet'
20
 
 
 
21
 
22
- def generate_text(input_text):
23
- if model_id:
24
- input_ids = tokenizer.encode(input_text, return_tensors="pt")
25
- attention_mask = torch.ones(input_ids.shape)
26
-
27
- output = model.generate(
28
- input_ids,
29
- attention_mask=attention_mask,
30
- max_length=200,
31
- do_sample=True,
32
- top_k=10,
33
- num_return_sequences=1,
34
- eos_token_id=tokenizer.eos_token_id,
35
- )
36
-
37
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
38
- print(output_text)
39
- else:
40
- output_text = model
41
- print(model)
42
  # Remove Prompt Echo from Generated Text
43
  cleaned_output_text = output_text.replace(input_text, "")
44
  return cleaned_output_text
@@ -50,5 +40,5 @@ text_generation_interface = gr.Interface(
50
  gr.inputs.Textbox(label="Input Text"),
51
  ],
52
  outputs=gr.inputs.Textbox(label="Generated Text"),
53
- title="Falcon-7B Instruct",
54
  ).launch()
 
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Hugging Face Hub id of the fine-tuned medical chat model served by this app.
model_id = "Mohammed-Altaf/Medical-ChatBot"

# Load weights once at import time. device_map="auto" lets accelerate place
# layers on whatever hardware is available; bfloat16 halves the memory
# footprint versus float32.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
13
 
 
14
 
15
+ def generate_text(input_text):
16
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
17
+ attention_mask = torch.ones(input_ids.shape)
18
+
19
+ output = model.generate(
20
+ input_ids,
21
+ attention_mask=attention_mask,
22
+ max_length=200,
23
+ do_sample=True,
24
+ top_k=10,
25
+ num_return_sequences=1,
26
+ eos_token_id=tokenizer.eos_token_id,
27
  )
 
 
 
 
28
 
29
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
30
+ print(output_text)
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Remove Prompt Echo from Generated Text
33
  cleaned_output_text = output_text.replace(input_text, "")
34
  return cleaned_output_text
 
40
  gr.inputs.Textbox(label="Input Text"),
41
  ],
42
  outputs=gr.inputs.Textbox(label="Generated Text"),
43
+ title="Medical ChatBot",
44
  ).launch()