import os
import time
import gradio as gr
from huggingface_hub import InferenceClient

bloom_repo = "bigscience/bloom"

bloom_template = """Text translation.
{source} text:
<s>{query}</s>
{target} translated text:
<s>"""

# Generation parameters passed to text_generation for every chunk
bloom_model_kwargs = dict(
    max_new_tokens=1000,
    temperature=0.3,
#    truncate=1512,
    seed=42,
    stop_sequences=["</s>", "<|endoftext|>", "<|end|>"],
    top_p=0.95,
    repetition_penalty=1.1,
)

client = InferenceClient(model=bloom_repo, token=os.environ.get("HUGGINGFACEHUB_API_TOKEN", None))
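
# Assumed setup: the token is read from the environment before launch, e.g.
#   export HUGGINGFACEHUB_API_TOKEN=hf_xxx
# If the variable is unset, requests may run unauthenticated and hit tighter rate limits.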

def split_text_into_chunks(text, chunk_size=1000):
    lines = text.split('\n')
    chunks = []
    chunk = ""
    for line in lines:
        # If adding the current line doesn't exceed the chunk size, add the line to the chunk
        if len(chunk) + len(line) <= chunk_size:
            chunk += line + "<newline>"
        else:
            # If adding the line exceeds chunk size, store the current chunk and start a new one
            chunks.append(chunk)
            chunk = line + "<newline>"
    # Don't forget the last chunk
    chunks.append(chunk)
    return chunks
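
# Rough example of the chunking behaviour (illustrative, with chunk_size=20):
#   split_text_into_chunks("first line\nsecond line\nthird", chunk_size=20)
#   -> ["first line<newline>", "second line<newline>", "third<newline>"]
# Newlines are replaced by the <newline> marker so they survive the prompt and
# can be restored after generation.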
    
def translation(source, target, text):
    output = ""
    result = ""
    chunks = split_text_into_chunks(text)
    for chunk in chunks:
        try:
            # Fill the prompt template with the language pair and the current chunk
            input_prompt = bloom_template.replace("{source}", source)
            input_prompt = input_prompt.replace("{target}", target)
            input_prompt = input_prompt.replace("{query}", chunk)
            stream = client.text_generation(input_prompt, stream=True, details=True, return_full_text=False, **bloom_model_kwargs)
            for response in stream:
                output += response.token.text
                # Strip a stop sequence if the generated text ends with one
                for stop_str in bloom_model_kwargs['stop_sequences']:
                    if output.endswith(stop_str):
                        output = output[:-len(stop_str)]
                # Stream the partial translation to the UI
                yield output.replace("<newline>", "\n").replace("</newline>", "\n")
            # output accumulates across chunks, so it already holds the full translation so far
            result = output
        except Exception as e:
            print(f"ERROR: LLM raised {e}")
        time.sleep(1)
    if result == "":
        result = text
    return result.replace("<newline>", "\n").replace("</newline>", "\n").strip()
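
# Illustrative direct use of the generator, bypassing the Gradio UI
# (language names and sample text are assumptions):
#   for partial in translation("English", "French", "Hello world"):
#       print(partial)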

gr.Interface(translation, inputs=["text","text","text"], outputs="text").queue(concurrency_count=100).launch()