CognitiveScience committed on
Commit 960cd05 · 1 Parent(s): 3516070

Update app.py

Files changed (1)
  1. app.py +82 -1
app.py CHANGED
@@ -24,11 +24,92 @@ import time
 
 from huggingface_hub import hf_hub_download
 
+
 #hf_hub_download(repo_id="CogSphere/aCogSphere", filename="./reviews.csv")
 
 from huggingface_hub import login
 from datasets import load_dataset
 
+client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.1"
+)
+
+
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+def generate(
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    formatted_prompt = format_prompt(prompt, history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    output = ""
+
+    for response in stream:
+        output += response.token.text
+        yield output
+    return output
+
+
+additional_inputs = [
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=5000,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    )
+]
+
 #dataset = load_dataset("csv", data_files="./data.csv")
 
 
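The `format_prompt` helper added above implements the Mistral-7B-Instruct chat template: every past user turn is wrapped in [INST] ... [/INST], each bot reply is terminated with </s>, and the new message lands in a final [INST] block. A quick sketch of the rendered string (the function is copied verbatim from the hunk; the sample messages are hypothetical):

# format_prompt is copied from the hunk above; history/message are made up.
def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hi there", "Hello! How can I help?")]
print(format_prompt("Summarize our chat", history))
# -> <s>[INST] Hi there [/INST] Hello! How can I help?</s> [INST] Summarize our chat [/INST]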
 
@@ -256,5 +337,5 @@ scheduler2.start()
 scheduler3 = BackgroundScheduler()
 scheduler3.add_job(func=backup_db_csv, trigger="interval", seconds=3666)
 scheduler3.start()
-
+demo.queue()
 demo.launch()
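Note that neither hunk's context lines import `InferenceClient` or `gr`, so the added code leans on imports made elsewhere in app.py. A minimal, self-contained sketch of the streaming call, assuming the standard `huggingface_hub` client (`stream_tokens` is a hypothetical name, not from the commit):

# Assumed imports; the diff never shows them, so the unchanged part of
# app.py presumably provides gr and InferenceClient already.
import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

def stream_tokens(prompt):
    # With stream=True and details=True, text_generation yields per-token
    # objects; response.token.text is the newly generated fragment.
    output = ""
    for response in client.text_generation(
        prompt, max_new_tokens=64, stream=True, details=True, return_full_text=False
    ):
        output += response.token.text
        yield output  # running transcript, same pattern as generate() above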
 
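The second hunk's one functional change is calling `demo.queue()` before `demo.launch()`. Gradio routes generator handlers such as `generate` through its event queue, so enabling it is what lets the token-by-token yields above actually stream to the browser (Gradio 3.x errors out on un-queued generators). The commit never shows how `demo` is constructed; one plausible wiring, hypothetical here, is a `gr.ChatInterface`:

# Hypothetical wiring; the commit does not show how demo is built.
demo = gr.ChatInterface(
    fn=generate,                          # generator, so each yield streams to the UI
    additional_inputs=additional_inputs,  # the four sliders defined in the first hunk
)

demo.queue()   # enable the event queue so the streaming generator works
demo.launch()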