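"""Gradio Space that streams text translations from bigscience/bloom via the Hugging Face Inference API."""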
import os
import time
import gradio as gr
from huggingface_hub import InferenceClient
bloom_repo = "bigscience/bloom"
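# Prompt template: BLOOM is expected to continue after the final <s> with the {target}
# translation and emit </s> when it is done.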
bloom_template = """Text translation.
{source} text:
<s>{query}</s>
{target} translated text:
<s>"""
bloom_model_kwargs = dict(
    max_new_tokens=1000,
    temperature=0.3,
    # truncate=1512,
    seed=42,
    stop_sequences=["</s>", "<|endoftext|>", "<|end|>"],
    top_p=0.95,
    repetition_penalty=1.1,
)
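# The token is read from the environment; if unset, the client falls back to
# anonymous requests against the public Inference API.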
client = InferenceClient(model=bloom_repo, token=os.environ.get("HUGGINGFACEHUB_API_TOKEN", None))
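# Split long inputs into ~chunk_size-character pieces so each prompt fits the model's
# context window. "<newline>" marks original line breaks so they can be restored later.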
def split_text_into_chunks(text, chunk_size=1000):
    lines = text.split('\n')
    chunks = []
    chunk = ""
    for line in lines:
        # If adding the current line doesn't exceed the chunk size, add the line to the chunk
        if len(chunk) + len(line) <= chunk_size:
            chunk += line + "<newline>"
        else:
            # If adding the line exceeds chunk size, store the current chunk and start a new one
            chunks.append(chunk)
            chunk = line + "<newline>"
    # Don't forget the last chunk
    chunks.append(chunk)
    return chunks
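# Generator: translates the input chunk by chunk and yields the accumulated
# translation so the Gradio UI can stream partial results.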
def translation(source, target, text):
    result = ""
    chunks = split_text_into_chunks(text)
    for chunk in chunks:
        # Reset per-chunk output so finished chunks are not re-appended to result below.
        output = ""
        try:
            input_prompt = bloom_template.replace("{source}", source)
            input_prompt = input_prompt.replace("{target}", target)
            input_prompt = input_prompt.replace("{query}", chunk)
            stream = client.text_generation(input_prompt, stream=True, details=True, return_full_text=False, **bloom_model_kwargs)
            for response in stream:
                output += response.token.text
                # Trim any stop sequence the model echoes back.
                for stop_str in bloom_model_kwargs['stop_sequences']:
                    if output.endswith(stop_str):
                        output = output[:-len(stop_str)]
                # Stream everything translated so far, restoring the original line breaks.
                yield (result + output).replace("<newline>", "\n").replace("</newline>", "\n")
            result += output
        except Exception as e:
            print(f"ERROR: LLM call failed with {e}")
            time.sleep(1)
    # Fall back to the untranslated input if every chunk failed.
    if result == "":
        result = text
    # Gradio ignores a generator's return value, so yield the final cleaned text instead.
    yield result.replace("<newline>", "\n").replace("</newline>", "\n").strip()
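# queue(concurrency_count=...) is the Gradio 3.x API; Gradio 4 replaced it with
# default_concurrency_limit.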
gr.Interface(translation, inputs=["text","text","text"], outputs="text").queue(concurrency_count=100).launch()
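
# Example (hypothetical languages/text): iterating the generator yields progressively
# longer partial translations, e.g.:
#   for partial in translation("English", "French", "Hello, world!"):
#       print(partial)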