import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ---------------------------------------------------------------------------
# 1) Load the model and tokenizer
# ---------------------------------------------------------------------------
# If you want to load in 8-bit or 4-bit precision with bitsandbytes,
# uncomment and install bitsandbytes, and set load_in_8bit=True or load_in_4bit=True.
# For example:
#
# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,                      # or load_in_8bit=True
#     bnb_4bit_compute_dtype=torch.float16,   # recommended compute dtype
#     bnb_4bit_use_double_quant=True,         # optional
#     bnb_4bit_quant_type='nf4',              # optional
# )
#
# model = AutoModelForCausalLM.from_pretrained(
#     "cheberle/autotrain-35swc-b4r9z",
#     quantization_config=bnb_config,
#     device_map="auto",
#     trust_remote_code=True
# )
# tokenizer = AutoTokenizer.from_pretrained("cheberle/autotrain-35swc-b4r9z", trust_remote_code=True)

# Single source of truth for the checkpoint so the model and tokenizer cannot
# drift apart when the id is changed.
MODEL_ID = "cheberle/autotrain-35swc-b4r9z"

# For a standard FP16 or FP32 load (no bitsandbytes):
# NOTE: trust_remote_code=True allows Python code shipped in the model
# repository to run at load time; keep it only for repositories you trust.
# NOTE(review): no device_map or .to(...) is given, so the model stays on the
# loader's default device (presumably CPU) - confirm float16 generation is
# supported there, or switch to float32 / add device placement.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # Or "auto", or float32
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)


# ---------------------------------------------------------------------------
# 2) Define a text generation function
# ---------------------------------------------------------------------------
def generate_text(
    prompt: str,
    *,
    max_new_tokens: int = 128,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> str:
    """Generate a sampled continuation of *prompt* with the loaded model.

    Args:
        prompt: Text to feed to the model.
        max_new_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The first generated sequence, decoded in full with special tokens
        skipped (nothing is sliced off, so any prompt tokens it contains are
        kept). An empty string is returned when *prompt* is empty or ``None``.
    """
    # Guard: an empty prompt may tokenize to zero tokens (tokenizer-dependent),
    # and ``None`` is rejected by the tokenizer outright.
    if not prompt:
        return ""

    # Tokenize input and place the tensors on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate output; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    # Decode the first (and only requested) sequence.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


#
--------------------------------------------------------------------------- # 3) Create the Gradio interface # --------------------------------------------------------------------------- with gr.Blocks() as demo: gr.Markdown("