vilarin committed on
Commit c4c656e · verified · 1 Parent(s): 2bea947

Update app.py

Files changed (1)
  1. app.py +117 -98
app.py CHANGED
@@ -1,82 +1,75 @@
  import torch
  from PIL import Image
  import gradio as gr
- #import spaces
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
  import os
  from threading import Thread
-

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
- MODEL_ID = "CohereForAI/aya-23-8B"
- MODEL_ID2 = "CohereForAI/aya-23-35B"
- MODELS = os.environ.get("MODELS")
- MODEL_NAME = MODELS.split("/")[-1]
-
- TITLE = "<h1><center>Aya-23-Chatbox</center></h1>"
-
- DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></center></h3>'
-
- CSS = """
- .duplicate-button {
-     margin: auto !important;
-     color: white !important;
-     background: black !important;
-     border-radius: 100vh !important;
- }
- """
-
-
- #QUANTIZE
- QUANTIZE_4BIT = True
- USE_GRAD_CHECKPOINTING = True
- TRAIN_BATCH_SIZE = 2
- TRAIN_MAX_SEQ_LENGTH = 512
- USE_FLASH_ATTENTION = False
- GRAD_ACC_STEPS = 16
-
- quantization_config = None
-
- if QUANTIZE_4BIT:
-     quantization_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_use_double_quant=True,
-         bnb_4bit_compute_dtype=torch.bfloat16,
      )

- attn_implementation = None
- if USE_FLASH_ATTENTION:
-     attn_implementation="flash_attention_2"
-
- model = AutoModelForCausalLM.from_pretrained(
-     MODELS,
-     quantization_config=quantization_config,
-     attn_implementation=attn_implementation,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
- )
- tokenizer = AutoTokenizer.from_pretrained(MODELS)
-
- #@spaces.GPU()
- def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
-     print(f'message is - {message}')
-     print(f'history is - {history}')
-     conversation = []
-     for prompt, answer in history:
-         conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
-     conversation.append({"role": "user", "content": message})
-
-     print(f"Conversation is -\n{conversation}")

-     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)

      streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})

      generate_kwargs = dict(
          input_ids=input_ids,
          streamer=streamer,
-         max_new_tokens=max_new_tokens,
          do_sample=True,
          temperature=temperature,
      )
@@ -89,45 +82,71 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
          buffer += new_text
          yield buffer


-
- chatbot = gr.Chatbot(height=450)
-
- with gr.Blocks(css=CSS) as demo:
-     gr.HTML(TITLE)
-     gr.HTML(DESCRIPTION)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-     gr.ChatInterface(
-         fn=stream_chat,
-         chatbot=chatbot,
-         fill_height=True,
-         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-         additional_inputs=[
-             gr.Slider(
                  minimum=0,
                  maximum=1,
                  step=0.1,
-                 value=0.8,
-                 label="Temperature",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=128,
-                 maximum=4096,
-                 step=1,
-                 value=1024,
-                 label="Max new tokens",
-                 render=False,
-             ),
-         ],
-         examples=[
-             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
-             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
-             ["Tell me a random fun fact about the Roman Empire."],
-             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
-         ],
-         cache_examples=False,
-     )


  if __name__ == "__main__":
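The removed handler assembles the chat prompt as a list of role/content messages and lets tokenizer.apply_chat_template render and tokenize it in one step. A minimal standalone sketch of that step, assuming the model and tokenizer loaded by the removed file are still in scope; the history and message values are invented for illustration:

```python
# Illustrative only: mirrors the prompt-building step of the removed stream_chat().
history = [("Hi!", "Hello! How can I help?")]   # (user, assistant) pairs from gr.ChatInterface
message = "Summarize your last answer in one sentence."

conversation = []
for prompt, answer in history:
    conversation.extend([
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer},
    ])
conversation.append({"role": "user", "content": message})

# tokenize=True plus add_generation_prompt=True returns input IDs ready for model.generate()
input_ids = tokenizer.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
```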
 
@@ -1,82 +1,75 @@
  import torch
  from PIL import Image
  import gradio as gr
+ import spaces
+ from transformers import LlamaForCausalLM, AutoTokenizer, BitsAndBytesConfig
  import os
  from threading import Thread
+ from polyglot.detect import Detector

  HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ MODEL = "LLaMAX/LLaMAX3-8B-Alpaca"
+
+ TITLE = "<h1><center>LLaMAX3-8B-Translation</center></h1>"
+
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+ model = LlamaForCausalLM.from_pretrained(
+     MODEL,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=quantization_config)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
+
+
+ def lang_detector(text):
+     min_chars = 5
+     if len(text) < min_chars:
+         return "Input text too short"
+     try:
+         detector = Detector(text).language
+         lang_info = str(detector)
+         code = re.search(r"name: (\w+)", lang_info).group(1)
+         return code
+     except Exception as e:
+         return f"ERROR:{str(e)}"
+
+ def Prompt_template(query, src_language, trg_language):
+     instruction = f'Translate the following sentences from {src_language} to {trg_language}.'
+     prompt = (
+         'Below is an instruction that describes a task, paired with an input that provides further context. '
+         'Write a response that appropriately completes the request.\n'
+         f'### Instruction:\n{instruction}\n'
+         f'### Input:\n{query}\n### Response:'
      )
+     return prompt

+ # Unfinished
+ def chunk_text():
+     pass

+ @spaces.GPU()
+ def translate(
+     source_text: str,
+     source_lang: str,
+     target_lang: str,
+     max_chunk: int,
+     max_length: int,
+     temperature: float):
+
+     print(f'Text is - {source_text}')
+
+     prompt = Prompt_template(source_text, source_lang, target_lang)
+     inputs = tokenizer(prompt, return_tensors="pt")
+
+     input_ids = inputs.to(model.device)

      streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True, "skip_prompt": True, 'clean_up_tokenization_spaces':False,})

      generate_kwargs = dict(
          input_ids=input_ids,
          streamer=streamer,
+         max_length=max_length,
          do_sample=True,
          temperature=temperature,
      )
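The diff collapses the middle of both handlers (old lines 83-88, new lines 76-81), which is where generation actually happens: model.generate is launched on a background Thread with the generate_kwargs built above, and the handler then iterates the TextIteratorStreamer, yielding the growing text back to Gradio. A minimal sketch of that idiom under those assumptions; the hidden lines themselves are not shown in this commit:

```python
from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, input_ids, **generate_kwargs):
    # Decoded text fragments appear on the streamer as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs on a worker thread while we consume the stream.
    Thread(target=model.generate,
           kwargs=dict(input_ids=input_ids, streamer=streamer, **generate_kwargs)).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # Gradio re-renders the accumulated string on each yield
```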
 
@@ -89,45 +82,71 @@
          buffer += new_text
          yield buffer

+ CSS = """
+ h1 {
+     text-align: center;
+     display: block;
+     height: 10vh;
+     align-content: center;
+ }
+ footer {
+     visibility: hidden;
+ }
+ """

+ chatbot = gr.Chatbot(height=600)
+
+ with gr.Blocks(theme="soft", css=CSS) as demo:
+     gr.Markdown(TITLE)
+     with gr.Row():
+         with gr.Column(scale=1):
+             source_lang = gr.Textbox(
+                 label="Source Lang(Auto-Detect)",
+                 value="English",
+             )
+             target_lang = gr.Textbox(
+                 label="Target Lang",
+                 value="Spanish",
+             )
+             max_chunk = gr.Slider(
+                 label="Max tokens Per Chunk",
+                 minimum=512,
+                 maximum=2046,
+                 value=1000,
+                 step=8,
+             )
+             max_length = gr.Slider(
+                 label="Context Window",
+                 minimum=512,
+                 maximum=8192,
+                 value=4096,
+                 step=8,
+             )
+             temperature = gr.Slider(
+                 label="Temperature",
                  minimum=0,
                  maximum=1,
+                 value=0.3,
                  step=0.1,
+             )
+         with gr.Column(scale=4):
+             gr.Markdown(DESCRIPTION)
+             source_text = gr.Textbox(
+                 label="Source Text",
+                 value="How we live is so different from how we ought to live that he who studies "+\
+                     "what ought to be done rather than what is done will learn the way to his downfall "+\
+                     "rather than to his preservation.",
+                 lines=10,
+             )
+             output_text = gr.Textbox(
+                 label="Output Text",
+                 lines=10,
+             )
+     with gr.Row():
+         submit = gr.Button(value="Submit")
+         clear = gr.ClearButton([source_text, output_text])
+
+     submit.click(fn=huanik, inputs=[source_lang, target_lang, source_text, max_chunk, max_length, temperature], outputs=[output_text])


  if __name__ == "__main__":
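For reference, a small usage sketch of the new helpers (not part of the commit): Prompt_template wraps the text in the Alpaca-style block that LLaMAX/LLaMAX3-8B-Alpaca expects, and lang_detector guesses the source language with polyglot. It assumes the two functions above are in scope, that polyglot is installed, and that import re is available, since lang_detector calls re.search:

```python
# Usage sketch; the input text and detected language are illustrative.
text = "Wie wir leben, unterscheidet sich stark davon, wie wir leben sollten."
src = lang_detector(text)                      # e.g. "German", or an "ERROR:..." string
prompt = Prompt_template(text, src, "English")
print(prompt)
# Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
# ### Instruction:
# Translate the following sentences from German to English.
# ### Input:
# Wie wir leben, unterscheidet sich stark davon, wie wir leben sollten.
# ### Response:
```

On the UI side, gr.Button.click passes the inputs components to the handler positionally, so whichever function is bound as fn must accept the six values in the order of that list.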