import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---------------------------------------------------------------------------
# 1) Load the model and tokenizer
# ---------------------------------------------------------------------------
# If you want to load in 8-bit or 4-bit precision with bitsandbytes,
# uncomment and install bitsandbytes, and set load_in_8bit=True or
# load_in_4bit=True. For example:
#
#   from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,                     # or load_in_8bit=True
#       bnb_4bit_compute_dtype=torch.float16,  # recommended compute dtype
#       bnb_4bit_use_double_quant=True,        # optional
#       bnb_4bit_quant_type='nf4',             # optional
#   )
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "cheberle/autotrain-35swc-b4r9z",
#       quantization_config=bnb_config,
#       device_map="auto",
#       trust_remote_code=True
#   )
#   tokenizer = AutoTokenizer.from_pretrained(
#       "cheberle/autotrain-35swc-b4r9z", trust_remote_code=True
#   )

# Hub repository used for both the model weights and the tokenizer.
MODEL_ID = "cheberle/autotrain-35swc-b4r9z"

# Half precision is only used when a CUDA device is available; on a CPU-only
# host the model is loaded in float32, because float16 kernels are missing or
# very slow on many CPU builds of PyTorch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# For a standard FP16 or FP32 load (no bitsandbytes):
# NOTE(review): trust_remote_code=True runs Python code shipped in the model
# repository -- keep it only if that repository is trusted.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=TORCH_DTYPE,
    trust_remote_code=True,
)
model.to(DEVICE)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)


# ---------------------------------------------------------------------------
# 2) Define a text generation function
# ---------------------------------------------------------------------------
def generate_text(prompt):
    """Generate a sampled continuation of *prompt* with the loaded model.

    Args:
        prompt: Text entered by the user. ``None``, empty, or
            whitespace-only input is treated as "nothing to generate".

    Returns:
        The first generated sequence decoded to a string with special
        tokens removed, or ``""`` when the prompt is empty.
    """
    # Nothing to condition on: skip tokenization and generation entirely.
    if not prompt or not prompt.strip():
        return ""

    # Tokenize input and place the tensors on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Fall back to the EOS id when the tokenizer defines no pad token, so
    # generate() does not have to guess one on every call.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Generate output (configure generation args as needed).
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Decode the first (only) sequence in the batch.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# 3) Create the Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # Page title. Written with explicit "\n" escapes: a plain double-quoted
    # literal cannot span several physical lines.
    gr.Markdown("\n\nDemo: cheberle/autotrain-35swc-b4r9z\n\n")

    with gr.Row():
        # Left column: prompt input and the trigger button.
        with gr.Column():
            prompt_in = gr.Textbox(
                lines=5,
                label="Enter your prompt",
                placeholder="Ask something here...",
            )
            submit_btn = gr.Button("Generate")

        # Right column: generated text.
        with gr.Column():
            output_box = gr.Textbox(lines=15, label="Model Output")

    # Define what happens on button click. The handler is registered inside
    # the Blocks context so it is attached to this app.
    submit_btn.click(fn=generate_text, inputs=prompt_in, outputs=output_box)

# ---------------------------------------------------------------------------
# 4) Launch!
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo.launch()