import torch
import streamlit as st

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, GenerationConfig
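
# Streamlit demo: job-description generation with the "Aneeth/zephyr_10k" PEFT adapter.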
# Alternative: load the base model directly with transformers.AutoModelForCausalLM:
# model = AutoModelForCausalLM.from_pretrained(
#     "tiiuae/falcon-7b-instruct",
#     torch_dtype=torch.bfloat16,
#     trust_remote_code=True,
#     device_map="auto",
#     low_cpu_mem_usage=True,
# )
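
# Load the fine-tuned adapter; AutoPeftModelForCausalLM resolves and wraps its base model automatically.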
model = AutoPeftModelForCausalLM.from_pretrained(
    "Aneeth/zephyr_10k",
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("Aneeth/zephyr_10k")
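
# Decoding settings. Note: with top_k=1, sampling is effectively greedy,
# so the temperature value has little practical effect.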
generation_config = GenerationConfig(
    do_sample=True,
    top_k=1,
    temperature=0.5,
    max_new_tokens=5000,
    pad_token_id=tokenizer.eos_token_id,
)
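

# Wrap a raw instruction in the Zephyr-style chat format (<|system|>/<|user|>/<|assistant|>).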
def process_data_sample(example):
    processed_example = (
        "<|system|>\n Generate an authentic job description using the given input.\n"
        "<|user|>\n" + example["instruction"] + "\n<|assistant|>\n"
    )
    return processed_example
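

# Tokenize the formatted prompt, run generation, and decode the model's reply.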
def generate_text(prompt):
    inp_str = process_data_sample({"instruction": prompt})
    # Send inputs to the device the model was dispatched to (device_map="auto"),
    # rather than pinning them to the CPU.
    inputs = tokenizer(inp_str, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    # Decode only the newly generated tokens so the echoed prompt is not returned.
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
    return response
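

# Streamlit UI: a text area for the prompt and a button that triggers generation.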
def main():
    st.title("Zephyr Inference")
    # Get input from user
    input_text = st.text_area("Input JD prompt", "Type here...")
    # Generate text on button click
    if st.button("Generate Text"):
        generated_text = generate_text(input_text)
        st.subheader("Generated Text:")
        st.write(generated_text)


if __name__ == "__main__":
    main()
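
# Launch with: streamlit run <path-to-this-script>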