llm_test2 / app.py
raduqus's picture
Update app.py
dc20801 verified
raw
history blame
1.84 kB
import gradio as gr
import numpy as np
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Settings ---------------------------------------------------------------
# Upper bound offered for the sampling seed: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
model_id = "raduqus/reco_1b_16bit"

# --- Tokenizer and model (loaded once, at import time) ----------------------
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # request half-precision weights
    device_map="auto",          # let the loader choose device placement
)
def infer(
    prompt,
    seed=0,
    randomize_seed=True,
    max_length=100,
    temperature=0.7,
    top_p=0.9,
):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Args:
        prompt: Text passed to the module-level ``tokenizer``.
        seed: Seed for the sampling RNG; used only when ``randomize_seed``
            is false.
        randomize_seed: When true, ``seed`` is replaced by a value drawn
            uniformly from ``[0, MAX_SEED]``.
        max_length: Forwarded to ``generate`` as ``max_length``. Note that
            this bounds the total sequence (prompt tokens included), not
            just the newly generated part.
        temperature: Sampling temperature forwarded to ``generate``.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.

    Returns:
        The first generated sequence decoded to a string with special
        tokens skipped.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # `generate` has no `generator` argument (unknown keyword arguments are
    # rejected), so reproducibility is obtained by seeding torch's global RNG
    # right before sampling. UI widgets may hand the value over as a float,
    # hence the explicit int().
    torch.manual_seed(int(seed))

    # Tokenize and move the tensors to wherever the model was placed; with
    # device_map="auto" that is not necessarily the CPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly instead of making generate() guess it;
        # None is accepted if the tokenizer did not produce one.
        attention_mask=inputs.get("attention_mask"),
        max_length=int(max_length),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Build the Gradio interface around `infer`. The input components are listed
# in the same order as `infer`'s parameters (prompt, seed, randomize_seed,
# max_length, temperature, top_p).
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt"),
        # step=1: without an explicit step the slider derives one from its
        # range, which for 0..MAX_SEED is far coarser than 1 and would make
        # most seeds impossible to pick.
        gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=0, step=1),
        gr.Checkbox(label="Randomize Seed", value=True),
        # Token counts are whole numbers, so pin the step explicitly.
        gr.Slider(label="Max Length", minimum=10, maximum=200, value=100, step=1),
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7),
        gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.9),
    ],
    outputs=gr.Textbox(label="Recommendation"),
    title="Task Recommender",
    description="Generate task recommendations",
)
# Turn on request queuing and keep the API endpoints open. `api_open` is an
# option of queue(), not of launch(), and launch() no longer takes
# `enable_queue`; passing either to launch() fails with an unexpected-keyword
# error, so queuing is configured through queue() instead.
demo.queue(api_open=True)

# Start the web server on the fixed port.
demo.launch(server_port=7860)