import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
# Fine-tuned Phi-2 checkpoint (QLoRA on the ChatML-formatted OASST1 dataset)
model_name = "RaviNaik/Phi2-Osst"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
)
model.config.use_cache = False
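# BitsAndBytesConfig is imported above but unused; a minimal sketch of how it
# could load the model in 4-bit instead (assumes a CUDA GPU with the
# bitsandbytes package installed), kept commented out to preserve behavior:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, quantization_config=bnb_config, trust_remote_code=True, device_map=device
# )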
# Tokenizers are device-agnostic, so the device_map argument is dropped here;
# Phi-2 defines no pad token, so the EOS token is reused for padding.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
# ChatML prompt template used during fine-tuning (note the <|im_start|> /
# <|im_end|> special tokens on every turn).
chat_template = """<|im_start|>system
You are a helpful assistant who always responds to user queries<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
def generate(prompt, max_length, temperature, num_samples):
    prompt = prompt.strip()
    # do_sample=True lets temperature take effect and allows num_return_sequences > 1;
    # note that max_length counts the prompt tokens as well as the generated ones.
    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=max_length, temperature=temperature, num_return_sequences=num_samples, do_sample=True)
    # result = pipe(chat_template.format(prompt=prompt))
    result = pipe(prompt)
    return {output: result}
with gr.Blocks() as app:
    gr.Markdown("## ERA Session27 - Phi2 Model Finetuning with QLoRA on OpenAssistant Conversations Dataset (OASST1)")
    gr.Markdown(
        """This is an implementation of [Phi2](https://huggingface.co/microsoft/phi-2) model finetuning using the QLoRA strategy on the [OpenAssistant Conversations Dataset (OASST1)](https://huggingface.co/datasets/OpenAssistant/oasst1).
Please find the source code and training details [here](https://github.com/RaviNaik/ERA-SESSION27).
Dataset used to finetune: [OpenAssistant Conversations Dataset (OASST1)](https://huggingface.co/datasets/OpenAssistant/oasst1)
ChatML-modified OASST dataset: [RaviNaik/oasst1-chatml](https://huggingface.co/datasets/RaviNaik/oasst1-chatml)
Finetuned model: [RaviNaik/Phi2-Osst](https://huggingface.co/RaviNaik/Phi2-Osst)
"""
    )
    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Initial Prompt", interactive=True)
            max_length = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Select Number of Tokens to be Generated",
                interactive=True,
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1,
                value=0.7,
                step=0.1,
                label="Select Temperature",
                interactive=True,
            )
            num_samples = gr.Dropdown(
                choices=[1, 2, 5, 10],
                value=1,
                interactive=True,
                label="Select No. of Outputs to be Generated",
            )
            submit_btn = gr.Button(value="Generate")
        with gr.Column():
            output = gr.JSON(label="Generated Text")
    submit_btn.click(
        generate,
        inputs=[prompt_box, max_length, temperature, num_samples],
        outputs=[output],
    )

app.launch()