RAMYASRI-39 committed on
Commit
dfae233
·
verified ·
1 Parent(s): 62a1b90

Update backend/query_llm.py

Browse files
Files changed (1) hide show
  1. backend/query_llm.py +176 -175
backend/query_llm.py CHANGED
@@ -1,175 +1,176 @@
1
-
2
-
3
- import openai
4
- import gradio as gr
5
-
6
- from os import getenv
7
- from typing import Any, Dict, Generator, List
8
-
9
- from huggingface_hub import InferenceClient
10
- from transformers import AutoTokenizer
11
- from gradio_client import Client
12
- #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
13
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
14
- #tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")
15
- temperature = 0.5
16
- top_p = 0.7
17
- repetition_penalty = 1.2
18
-
19
- OPENAI_KEY = getenv("OPENAI_API_KEY")
20
- HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
21
-
22
- # hf_client = InferenceClient(
23
- # "mistralai/Mistral-7B-Instruct-v0.1",
24
- # token=HF_TOKEN
25
- # )
26
-
27
- client = Client("Qwen/Qwen1.5-110B-Chat-demo")
28
- hf_client = InferenceClient(
29
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
30
- token=HF_TOKEN
31
- )
32
- def format_prompt(message: str, api_kind: str):
33
- """
34
- Formats the given message using a chat template.
35
-
36
- Args:
37
- message (str): The user message to be formatted.
38
-
39
- Returns:
40
- str: Formatted message after applying the chat template.
41
- """
42
-
43
- # Create a list of message dictionaries with role and content
44
- messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
45
-
46
- if api_kind == "openai":
47
- return messages
48
- elif api_kind == "hf":
49
- return tokenizer.apply_chat_template(messages, tokenize=False)
50
- elif api_kind:
51
- raise ValueError("API is not supported")
52
-
53
-
54
- def generate_hf(prompt: str, history: str, temperature: float = 0.5, max_new_tokens: int = 4000,
55
- top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
56
- """
57
- Generate a sequence of tokens based on a given prompt and history using Mistral client.
58
-
59
- Args:
60
- prompt (str): The initial prompt for the text generation.
61
- history (str): Context or history for the text generation.
62
- temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
63
- max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
64
- top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
65
- repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
66
-
67
- Returns:
68
- Generator[str, None, str]: A generator yielding chunks of generated text.
69
- Returns a final string if an error occurs.
70
- """
71
-
72
- temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
73
- top_p = float(top_p)
74
-
75
- generate_kwargs = {
76
- 'temperature': temperature,
77
- 'max_new_tokens': max_new_tokens,
78
- 'top_p': top_p,
79
- 'repetition_penalty': repetition_penalty,
80
- 'do_sample': True,
81
- 'seed': 42,
82
- }
83
-
84
- formatted_prompt = format_prompt(prompt, "hf")
85
-
86
- try:
87
- stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
88
- stream=True, details=True, return_full_text=False)
89
- output = ""
90
- for response in stream:
91
- output += response.token.text
92
- yield output
93
-
94
- except Exception as e:
95
- if "Too Many Requests" in str(e):
96
- print("ERROR: Too many requests on Mistral client")
97
- gr.Warning("Unfortunately Mistral is unable to process")
98
- return "Unfortunately, I am not able to process your request now."
99
- elif "Authorization header is invalid" in str(e):
100
- print("Authetification error:", str(e))
101
- gr.Warning("Authentication error: HF token was either not provided or incorrect")
102
- return "Authentication error"
103
- else:
104
- print("Unhandled Exception:", str(e))
105
- gr.Warning("Unfortunately Mistral is unable to process")
106
- return "I do not know what happened, but I couldn't understand you."
107
-
108
- def generate_qwen(formatted_prompt: str, history: str):
109
- response = client.predict(
110
- query=formatted_prompt,
111
- history=[],
112
- system='You are wonderful',
113
- api_name="/model_chat"
114
- )
115
- print('Response:',response)
116
-
117
- #return output
118
- #return response[1][0][1]
119
- return response[1][0][1]
120
-
121
-
122
-
123
- def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
124
- top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
125
- """
126
- Generate a sequence of tokens based on a given prompt and history using Mistral client.
127
-
128
- Args:
129
- prompt (str): The initial prompt for the text generation.
130
- history (str): Context or history for the text generation.
131
- temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
132
- max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
133
- top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
134
- repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.
135
-
136
- Returns:
137
- Generator[str, None, str]: A generator yielding chunks of generated text.
138
- Returns a final string if an error occurs.
139
- """
140
-
141
- temperature = max(float(temperature), 1e-2) # Ensure temperature isn't too low
142
- top_p = float(top_p)
143
-
144
- generate_kwargs = {
145
- 'temperature': temperature,
146
- 'max_tokens': max_new_tokens,
147
- 'top_p': top_p,
148
- 'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
149
- }
150
-
151
- formatted_prompt = format_prompt(prompt, "openai")
152
-
153
- try:
154
- stream = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301",
155
- messages=formatted_prompt,
156
- **generate_kwargs,
157
- stream=True)
158
- output = ""
159
- for chunk in stream:
160
- output += chunk.choices[0].delta.get("content", "")
161
- yield output
162
-
163
- except Exception as e:
164
- if "Too Many Requests" in str(e):
165
- print("ERROR: Too many requests on OpenAI client")
166
- gr.Warning("Unfortunately OpenAI is unable to process")
167
- return "Unfortunately, I am not able to process your request now."
168
- elif "You didn't provide an API key" in str(e):
169
- print("Authetification error:", str(e))
170
- gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
171
- return "Authentication error"
172
- else:
173
- print("Unhandled Exception:", str(e))
174
- gr.Warning("Unfortunately OpenAI is unable to process")
175
- return "I do not know what happened, but I couldn't understand you."
 
 
1
+
2
+
3
+ import openai
4
+ import gradio as gr
5
+
6
+ from os import getenv
7
+ from typing import Any, Dict, Generator, List
8
+
9
+ from huggingface_hub import InferenceClient
10
+ from transformers import AutoTokenizer
11
+ from gradio_client import Client
12
# Tokenizer whose chat template format_prompt() applies for the "hf" API kind.
# The commented-out lines are alternative checkpoints kept for reference.
#tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
#tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x22B-Instruct-v0.1")

# Module-level sampling settings. The generate_* functions in this file take
# same-named parameters that shadow these, so they are not read here;
# presumably other modules import them -- verify against callers.
temperature = 0.5
top_p = 0.7
repetition_penalty = 1.2

# Credentials read from the environment; getenv() yields None when unset.
# OPENAI_KEY is not referenced by the code visible in this file, and HF_TOKEN
# is only referenced by the commented-out InferenceClient constructions below.
OPENAI_KEY = getenv("OPENAI_API_KEY")
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")

# hf_client = InferenceClient(
# "mistralai/Mistral-7B-Instruct-v0.1",
# token=HF_TOKEN
# )

# Gradio client used by generate_qwen(); constructed at import time.
client = Client("Qwen/Qwen1.5-110B-Chat-demo")

# Placeholder: the Hugging Face inference client is disabled in this revision.
# generate_hf() still calls hf_client.text_generation(...); a str has no such
# attribute, so that call raises AttributeError inside generate_hf's try block
# and ends up in its generic "Unhandled Exception" branch.
hf_client=''
# hf_client = InferenceClient(
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
# token=HF_TOKEN
# )
33
def format_prompt(message: str, api_kind: str):
    """
    Wrap a single user message in the prompt format expected by the target API.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): Target API. "openai" selects the chat-message list;
            "hf" selects the string produced by applying the module-level
            tokenizer's chat template.

    Returns:
        For "openai": a one-element list ``[{'role': 'user', 'content': message}]``.
        For "hf": the result of ``tokenizer.apply_chat_template(messages, tokenize=False)``.

    Raises:
        ValueError: If api_kind is anything other than "openai" or "hf".
    """

    # Create a list of message dictionaries with role and content
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]

    if api_kind == "openai":
        return messages
    if api_kind == "hf":
        return tokenizer.apply_chat_template(messages, tokenize=False)

    # Reject every other value, including None and "". The earlier
    # `elif api_kind:` guard let falsy values fall through and return None,
    # which only surfaced later as a confusing failure in the caller.
    raise ValueError("API is not supported")
53
+
54
+
55
def generate_hf(prompt: str, history: str, temperature: float = 0.5, max_new_tokens: int = 4000,
                top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream a completion for ``prompt`` from the module-level ``hf_client``.

    The prompt is templated with ``format_prompt(prompt, "hf")`` and sent to
    ``hf_client.text_generation`` in streaming mode. After each streamed token
    the text accumulated so far is yielded.

    Args:
        prompt (str): The user prompt for the text generation.
        history (str): Part of the signature; not referenced by this function.
        temperature (float, optional): Sampling temperature. Defaults to 0.5;
            values below 0.01 are raised to 0.01.
        max_new_tokens (int, optional): Maximum number of tokens to generate.
            Defaults to 4000.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens.
            Defaults to 1.0.

    Returns:
        Generator[str, None, str]: Yields the cumulative generated text.
            If the request fails, a Gradio warning is issued and a short
            message becomes the generator's return value. That value travels
            in ``StopIteration.value``, so a plain ``for`` loop over the
            generator does not receive it.
    """

    sampling_temperature = max(float(temperature), 1e-2)  # keep temperature strictly positive
    nucleus_p = float(top_p)

    generate_kwargs = dict(
        temperature=sampling_temperature,
        max_new_tokens=max_new_tokens,
        top_p=nucleus_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, "hf")

    try:
        token_stream = hf_client.text_generation(
            formatted_prompt,
            stream=True,
            details=True,
            return_full_text=False,
            **generate_kwargs,
        )
        text_so_far = ""
        for event in token_stream:
            text_so_far = text_so_far + event.token.text
            yield text_so_far

    except Exception as e:
        # Failures are classified by substrings of the exception message.
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            return "Unfortunately, I am not able to process your request now."

        if "Authorization header is invalid" in str(e):
            print("Authetification error:", str(e))
            gr.Warning("Authentication error: HF token was either not provided or incorrect")
            return "Authentication error"

        print("Unhandled Exception:", str(e))
        gr.Warning("Unfortunately Mistral is unable to process")
        return "I do not know what happened, but I couldn't understand you."
108
+
109
def generate_qwen(formatted_prompt: str, history: str):
    """
    Send one query to the "/model_chat" endpoint of the module-level Gradio ``client``.

    Args:
        formatted_prompt (str): Text passed to the endpoint as ``query``.
        history (str): Part of the signature; not referenced by this function
            (the request always carries an empty history list).

    Returns:
        The element ``response[1][0][1]`` of the endpoint's result --
        presumably the assistant reply of the first chat turn; verify against
        the Space's API description.
    """
    request = {
        'query': formatted_prompt,
        'history': [],
        'system': 'You are wonderful',
        'api_name': "/model_chat",
    }
    response = client.predict(**request)
    print('Response:', response)

    chat_turns = response[1]
    first_turn = chat_turns[0]
    return first_turn[1]
121
+
122
+
123
+
124
def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
                    top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Stream a chat completion for ``prompt`` via ``openai.ChatCompletion.create``.

    The prompt is wrapped with ``format_prompt(prompt, "openai")`` and sent to
    the "gpt-3.5-turbo-0301" model in streaming mode. After each streamed chunk
    the text accumulated so far is yielded.

    Args:
        prompt (str): The user prompt for the text generation.
        history (str): Part of the signature; not referenced by this function.
        temperature (float, optional): Sampling temperature. Defaults to 0.9;
            values below 0.01 are raised to 0.01.
        max_new_tokens (int, optional): Sent as ``max_tokens``. Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Clamped to [-2, 2] and sent as
            ``frequency_penalty``. Defaults to 1.0.

    Returns:
        Generator[str, None, str]: Yields the cumulative generated text.
            If the request fails, a Gradio warning is issued and a short
            message becomes the generator's return value. That value travels
            in ``StopIteration.value``, so a plain ``for`` loop over the
            generator does not receive it.
    """

    sampling_temperature = max(float(temperature), 1e-2)  # keep temperature strictly positive
    nucleus_p = float(top_p)

    # Clamp in the same order as before: cap at 2 first, then floor at -2.
    capped_penalty = min(repetition_penalty, 2.)
    generate_kwargs = dict(
        temperature=sampling_temperature,
        max_tokens=max_new_tokens,
        top_p=nucleus_p,
        frequency_penalty=max(-2., capped_penalty),
    )

    formatted_prompt = format_prompt(prompt, "openai")

    try:
        chunk_stream = openai.ChatCompletion.create(
            model="gpt-3.5-turbo-0301",
            messages=formatted_prompt,
            stream=True,
            **generate_kwargs,
        )
        text_so_far = ""
        for chunk in chunk_stream:
            # Chunks without a "content" delta contribute nothing.
            piece = chunk.choices[0].delta.get("content", "")
            text_so_far = text_so_far + piece
            yield text_so_far

    except Exception as e:
        # Failures are classified by substrings of the exception message.
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on OpenAI client")
            gr.Warning("Unfortunately OpenAI is unable to process")
            return "Unfortunately, I am not able to process your request now."

        if "You didn't provide an API key" in str(e):
            print("Authetification error:", str(e))
            gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
            return "Authentication error"

        print("Unhandled Exception:", str(e))
        gr.Warning("Unfortunately OpenAI is unable to process")
        return "I do not know what happened, but I couldn't understand you."