import gradio as gr
import torch
from transformers import AutoTokenizer, GenerationConfig
from peft import AutoPeftModelForCausalLM
# Load the tokenizer and the LoRA-adapted model once at startup,
# instead of reloading them on every request.
tokenizer = AutoTokenizer.from_pretrained("izh97/zephyr-beta-climate-change-assistant")
model = AutoPeftModelForCausalLM.from_pretrained(
    "izh97/zephyr-beta-climate-change-assistant",
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cuda",  # already places the model on the GPU; no extra .to() call needed
)
# Sampling settings shared by every request.
generation_config = GenerationConfig(
    do_sample=True,
    top_k=10,
    temperature=0.2,
    max_new_tokens=256,
    pad_token_id=tokenizer.unk_token_id,
)
def ask(text):
    # Wrap the raw user string in the chat format the model was fine-tuned on.
    messages = [{"role": "user", "content": text}]
    # return_dict=True makes apply_chat_template return a BatchEncoding
    # (input_ids + attention_mask) that can be unpacked into generate();
    # without it, a bare tensor comes back and **inputs fails.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")
    input_length = inputs["input_ids"].shape[1]
    outputs = model.generate(
        **inputs,
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    # Strip the prompt tokens so only the newly generated answer is returned.
    tokens = outputs.sequences[0, input_length:]
    return tokenizer.decode(tokens, skip_special_tokens=True)
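
# The file imports gradio but never builds or launches an interface, so the
# Space would serve nothing. A minimal sketch of the missing wiring, assuming
# a plain text-in/text-out UI; the labels and title below are illustrative,
# not taken from the original app.
demo = gr.Interface(
    fn=ask,
    inputs=gr.Textbox(label="Ask about climate change"),
    outputs=gr.Textbox(label="Answer"),
    title="Zephyr Climate Change Assistant",
)

if __name__ == "__main__":
    demo.launch()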