import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---------------------------------------------------------------------------
# 1) Load the model and tokenizer
# ---------------------------------------------------------------------------
# To load in 8-bit or 4-bit precision, install bitsandbytes (and accelerate),
# then pass a BitsAndBytesConfig with load_in_8bit=True or load_in_4bit=True.
# For example:
#
# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,                      # or load_in_8bit=True
#     bnb_4bit_compute_dtype=torch.float16,   # recommended compute dtype
#     bnb_4bit_use_double_quant=True,         # optional
#     bnb_4bit_quant_type="nf4",              # optional
# )
#
# model = AutoModelForCausalLM.from_pretrained(
#     "cheberle/autotrain-35swc-b4r9z",
#     quantization_config=bnb_config,
#     device_map="auto",
#     trust_remote_code=True,
# )
# tokenizer = AutoTokenizer.from_pretrained("cheberle/autotrain-35swc-b4r9z", trust_remote_code=True)

# For a standard FP16 or FP32 load (no bitsandbytes):
model = AutoModelForCausalLM.from_pretrained(
    "cheberle/autotrain-35swc-b4r9z",
    torch_dtype=torch.float16,  # or "auto", or torch.float32
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    "cheberle/autotrain-35swc-b4r9z",
    trust_remote_code=True,
)
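
# With torch_dtype=torch.float16 and no device_map, the weights stay on the
# CPU, where half-precision generation is often unsupported. A minimal sketch
# for device placement (assumes a single-GPU machine; if you have accelerate
# installed, passing device_map="auto" to from_pretrained does this for you):
if torch.cuda.is_available():
    model = model.to("cuda")
else:
    model = model.float()  # fall back to fp32 on CPU, where fp16 ops may fail
model.eval()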

# ---------------------------------------------------------------------------
# 2) Define a text generation function
# ---------------------------------------------------------------------------
def generate_text(prompt):
    # Tokenize the input and move it to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate output (tune the generation arguments as needed)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # avoids a warning when no pad token is set
        )

    # Decode the full sequence (prompt + completion) back to text
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded
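
# The decoded string above echoes the prompt. If you only want the newly
# generated continuation, a small sketch (drop these lines into generate_text
# in place of the decode call):
#   new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
#   decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)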

# ---------------------------------------------------------------------------
# 3) Create the Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("<h3>Demo: cheberle/autotrain-35swc-b4r9z</h3>")
    with gr.Row():
        with gr.Column():
            prompt_in = gr.Textbox(
                lines=5,
                label="Enter your prompt",
                placeholder="Ask something here...",
            )
            submit_btn = gr.Button("Generate")
        with gr.Column():
            output_box = gr.Textbox(lines=15, label="Model Output")

    # Define what happens on button click
    submit_btn.click(fn=generate_text, inputs=prompt_in, outputs=output_box)
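
    # Optional: also run generation on the textbox's submit event
    # (Textbox.submit is standard Gradio API; note that for multi-line
    # textboxes the trigger key may differ from plain Enter).
    prompt_in.submit(fn=generate_text, inputs=prompt_in, outputs=output_box)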

# ---------------------------------------------------------------------------
# 4) Launch!
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()
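
# To bind to a specific host/port or create a temporary public link, pass the
# standard Gradio launch options instead (the values here are examples):
#   demo.launch(server_name="0.0.0.0", server_port=7860, share=True)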