wop committed
Commit 5fba34b · verified · 1 Parent(s): 0edd643

Update app.py

Files changed (1)
  1. app.py +43 -60
app.py CHANGED
@@ -1,64 +1,47 @@
-from huggingface_hub import InferenceClient
+import torch
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
 import gradio as gr
-import json
 
-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.3"
+# Check if a GPU is available and use it, otherwise use CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the pre-trained model and tokenizer from the saved directory
+model_path = "blexus_pretrained_test"
+tokenizer = GPT2Tokenizer.from_pretrained(model_path)
+model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
+
+# Set model to evaluation mode
+model.eval()
+
+# Function to generate text based on input prompt
+def generate_text(prompt):
+    # Tokenize and encode the input prompt
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+    # Generate continuation
+    with torch.no_grad():
+        generated_ids = model.generate(
+            input_ids,
+            max_length=50,  # Maximum length of generated text
+            num_return_sequences=1,  # Generate 1 sequence
+            pad_token_id=tokenizer.eos_token_id,  # Use EOS token for padding
+            do_sample=True,  # Enable sampling
+            top_k=50,  # Top-k sampling
+            top_p=0.95  # Nucleus sampling
+        )
+
+    # Decode the generated text
+    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    return generated_text
+
+# Create a Gradio interface
+interface = gr.Interface(
+    fn=generate_text,  # Function to call when interacting with the UI
+    inputs="text",  # Input type: single-line text
+    outputs="text",  # Output type: text (the generated output)
+    title="Quble Text Generation",  # Title of the UI
+    description="Enter a prompt to generate text using Quble."  # Simple description
 )
 
-DATABASE_PATH = "database.json"
-
-def load_database():
-    try:
-        with open(DATABASE_PATH, "r") as file:
-            return json.load(file)
-    except FileNotFoundError:
-        return {}
-
-def save_database(database):
-    with open(DATABASE_PATH, "w") as file:
-        json.dump(database, file)
-
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def generate(
-    prompt, history, temperature=0.9, max_new_tokens=4096, top_p=0.9, repetition_penalty=1.2,
-):
-    database = load_database()  # Load the database
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    formatted_prompt = format_prompt(prompt, history)
-    if formatted_prompt in database:
-        response = database[formatted_prompt]
-    else:
-        response = client.text_generation(formatted_prompt, details=True, return_full_text=False)
-        response_text = response.generated_text
-        database[formatted_prompt] = response_text
-        save_database(database)  # Save the updated database
-
-    yield response_text
-
-css = """
-#mkd {
-  height: 500px;
-  overflow: auto;
-  border: 1px solid #ccc;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    gr.ChatInterface(
-        generate,
-        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."], ["Write a short story about Paris."]]
-    )
-
-demo.launch(debug=True)
+# Launch the Gradio app
+interface.launch()
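
Reviewer note (not part of the commit): since the new app swaps gr.ChatInterface for a plain gr.Interface, it exposes Gradio's default /predict endpoint, so the change can be smoke-tested over the API with gradio_client. A minimal sketch, assuming app.py is already running and reachable at Gradio's default local URL; the URL and prompt below are assumptions, launch() prints the real address.

# Smoke test for the updated app via its Gradio API (assumptions noted above).
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # assumed local launch() URL
result = client.predict("Once upon a time", api_name="/predict")  # gr.Interface's default endpoint
print(result)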