File size: 3,746 Bytes
0e74df5
e520569
0e74df5
 
 
 
 
94e5cfe
e3be437
c6fbe87
e3be437
 
0e74df5
c6fbe87
 
 
 
 
 
 
 
 
 
 
 
 
 
0e74df5
d0d5021
 
 
 
 
00cdb9a
d0d5021
 
7d6bd00
d0d5021
 
 
 
 
 
 
 
 
 
7d6bd00
d0d5021
 
7d6bd00
d0d5021
 
 
 
0e74df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3be437
0e74df5
 
 
 
 
 
 
 
 
 
 
 
 
00cdb9a
0e74df5
d0d5021
 
0e74df5
 
 
 
 
 
 
 
 
 
 
 
 
00cdb9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e74df5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gradio as gr
import os
from huggingface_hub import InferenceClient

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

# Hugging Face model that the inference client below is bound to.
model_name = "meta-llama/Llama-3.2-1B"

# Access token taken from the SECRET_ENV_VARIABLE environment variable;
# the value is None when the variable is not set.
huggingface_token = os.environ.get("SECRET_ENV_VARIABLE")

# Module-level client shared by the generation functions in this file.
client = InferenceClient(token=huggingface_token, model=model_name)

'''
import requests

API_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-1B"
headers = {"Authorization": "Bearer "}

def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()
	
output = query({
	"inputs": "Can you please let us know more details about your ",
})
'''

def generate_text(
    prompt,
    system_message,
    max_tokens,
    temperature,
    top_p
):
    """Generate a text completion for ``prompt`` using the module-level client.

    Args:
        prompt: Text the model should continue.
        system_message: Accepted so the signature lines up with the UI
            inputs; it is not forwarded to ``client.text_generation``.
        max_tokens: Upper bound on newly generated tokens. Coerced to ``int``
            before being sent as ``max_new_tokens``.
        temperature: Sampling temperature, forwarded unchanged.
        top_p: Nucleus-sampling threshold, forwarded unchanged.

    Returns:
        The generated text on success. When the prompt is empty or the
        request fails, a human-readable message is returned instead of an
        exception being raised, so the caller always receives a string.
    """
    # Nothing to complete: answer locally instead of issuing a remote call.
    if prompt is None or not str(prompt).strip():
        return "Please enter a prompt."

    try:
        print(f"Attempting to generate text for prompt: {prompt[:50]}...")

        response = client.text_generation(
            prompt,
            # Defensive coercion in case the UI hands over a float such as
            # 512.0; a bad value is reported through the except branch.
            max_new_tokens=int(max_tokens),
            temperature=temperature,
            top_k=50,
            top_p=top_p,
            do_sample=True
        )

        print(f"Generated text: {response[:100]}...")
        return response
    except Exception as e:
        # Deliberately broad: any failure is logged and turned into a
        # message for the caller rather than propagating.
        print(f"Error in generate_text: {type(e).__name__}: {str(e)}")
        return f"An error occurred: {type(e).__name__}: {str(e)}"



def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply, yielding the accumulated text after every chunk.

    Args:
        message: The newest user message.
        history: Earlier turns as (user, assistant) pairs; falsy entries in
            a pair are skipped.
        system_message: Content of the leading ``system`` message.
        max_tokens: Forwarded to ``client.chat_completion`` as ``max_tokens``.
        temperature: Sampling temperature, forwarded unchanged.
        top_p: Nucleus-sampling threshold, forwarded unchanged.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": message})

    response = ""

    # The loop variable is named `chunk` so it no longer shadows the
    # `message` parameter.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed chunk may carry no text (empty `choices`, or a delta
        # whose content is None); appending None would raise TypeError.
        token = chunk.choices[0].delta.content if chunk.choices else None
        if token:
            response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface

demo = gr.ChatInterface(
    #respond,
    generate_text,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
"""

# Build the UI inside a Blocks context: this is what defines `demo` for the
# launch call below, and gr.Tab has to be created within gr.Blocks.
with gr.Blocks() as demo:
    with gr.Tab("Generate Email"):
        Query = gr.Textbox(label="Query")
        generate_button = gr.Button("Ask Query")
        output = gr.Textbox(label="Generated Answer", lines=10)

        # Generation settings handed to generate_text after the query.
        system_message_box = gr.Textbox(
            value="You are a friendly Chatbot.",
            label="System message",
        )
        max_tokens_slider = gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens",
        )
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        )
        top_p_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        )

        # Event listeners take `inputs=` (not `additional_inputs=`); the
        # order mirrors generate_text's positional parameters.
        generate_button.click(
            generate_text,
            inputs=[
                Query,
                system_message_box,
                max_tokens_slider,
                temperature_slider,
                top_p_slider,
            ],
            outputs=output,
        )


if __name__ == "__main__":
    demo.launch()