import argparse
import logging
import time
from threading import Thread

import torch
import gradio as gr
import spaces
# NOTE: `Conversation` and the "conversational" pipeline were removed from
# recent transformers releases; this script assumes a version that still ships them.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer, Conversation

from concept_guidance.chat_template import DEFAULT_CHAT_TEMPLATE
from concept_guidance.patching import patch_model, load_weights

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
device = torch.device("cuda")

# comment in/out the models you want to use
# RAM requirements: ~16GB x #models (+ ~4GB overhead)
# VRAM requirements: ~16GB
# if using int8: ~8GB VRAM x #models, low RAM requirements
MODEL_CONFIGS = {
    "Llama-2-7b-chat-hf": {
        "identifier": "meta-llama/Llama-2-7b-chat-hf",
        "dtype": torch.float16 if device.type == "cuda" else torch.float32,
        "load_in_8bit": False,
        "guidance_interval": [-16.0, 16.0],
        "default_guidance_scale": 8.0,
        "min_guidance_layer": 16,
        "max_guidance_layer": 32,
        "default_concept": "humor",
        "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
    },
    # "Mistral-7B-Instruct-v0.1": {
    #     "identifier": "mistralai/Mistral-7B-Instruct-v0.1",
    #     "dtype": torch.bfloat16 if device.type == "cuda" else torch.float32,
    #     "load_in_8bit": False,
    #     "guidance_interval": [-128.0, 128.0],
    #     "default_guidance_scale": 48.0,
    #     "min_guidance_layer": 8,
    #     "max_guidance_layer": 32,
    #     "default_concept": "humor",
    #     "concepts": ["humor", "creativity", "quality", "truthfulness", "compliance"],
    # },
}
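

# load one trained guidance vector per concept from
# trained_concepts/<model>/<concept>.safetensors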
def load_concept_vectors(model, concepts):
    return {concept: load_weights(f"trained_concepts/{model}/{concept}.safetensors") for concept in concepts}
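

# load a model and its tokenizer, falling back to a default chat template if none is set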
def load_model(model_name):
    config = MODEL_CONFIGS[model_name]
    model = AutoModelForCausalLM.from_pretrained(config["identifier"], torch_dtype=config["dtype"], load_in_8bit=config["load_in_8bit"])
    tokenizer = AutoTokenizer.from_pretrained(config["identifier"])
    if tokenizer.chat_template is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return model, tokenizer
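

# preload all configured models and concept vectors once at startup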
CONCEPTS = ["humor", "creativity", "quality", "truthfulness", "compliance"]
CONCEPT_VECTORS = {model_name: load_concept_vectors(model_name, CONCEPTS) for model_name in MODEL_CONFIGS}
MODELS = {model_name: load_model(model_name) for model_name in MODEL_CONFIGS}
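

# convert Gradio's [prompt, completion] pairs into a transformers Conversation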
def history_to_conversation(history):
    conversation = Conversation()
    for prompt, completion in history:
        conversation.add_message({"role": "user", "content": prompt})
        if completion is not None:
            conversation.add_message({"role": "assistant", "content": completion})
    return conversation
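

# reset the concept dropdown and guidance sliders to the selected model's defaults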
def set_defaults(model_name):
    config = MODEL_CONFIGS[model_name]
    return (
        model_name,
        gr.update(choices=config["concepts"], value=config["concepts"][0]),
        gr.update(minimum=config["guidance_interval"][0], maximum=config["guidance_interval"][1], value=config["default_guidance_scale"]),
        gr.update(value=config["min_guidance_layer"]),
        gr.update(value=config["max_guidance_layer"]),
    )
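

# append the new user turn with an empty assistant slot, to be filled in by streaming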
def add_user_prompt(user_message, history):
    if history is None:
        history = []
    history.append([user_message, None])
    return history
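

# NOTE: on ZeroGPU Spaces, functions that touch CUDA must run under the
# `spaces.GPU` decorator. `spaces` is imported above but otherwise unused,
# so the decorator is assumed to belong on the generation function:
@spaces.GPU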
def generate_completion(
    history,
    model_name,
    concept,
    guidance_scale=4.0,
    min_guidance_layer=16,
    max_guidance_layer=32,
    temperature=0.0,
    repetition_penalty=1.2,
    length_penalty=1.2,
):
    start_time = time.time()
    logger.info(f" --- Starting completion ({model_name}, {concept=}, {guidance_scale=}, {min_guidance_layer=}, {temperature=})")
    logger.info(" User: " + repr(history[-1][0]))

    # move all other models to CPU
    for name, (model, _) in MODELS.items():
        if name != model_name:
            config = MODEL_CONFIGS[name]
            if not config["load_in_8bit"]:
                model.to("cpu")
    torch.cuda.empty_cache()

    # load the model
    config = MODEL_CONFIGS[model_name]
    model, tokenizer = MODELS[model_name]
    if not config["load_in_8bit"]:
        model.to(device, non_blocking=True)

    # patch the selected concept vector into the chosen layer range
    # (sliders are 1-indexed, range() is 0-indexed)
    concept_vector = CONCEPT_VECTORS[model_name][concept]
    guidance_layers = list(range(int(min_guidance_layer) - 1, int(max_guidance_layer)))
    patch_model(model, concept_vector, guidance_scale=guidance_scale, guidance_layers=guidance_layers)
    pipe = pipeline("conversational", model=model, tokenizer=tokenizer, device=(device if not config["load_in_8bit"] else None))
    conversation = history_to_conversation(history)

    # generate in a background thread and stream tokens into the chat history
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        max_new_tokens=512,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        streamer=streamer,
        temperature=temperature,
        do_sample=(temperature > 0),
    )
    thread = Thread(target=pipe, args=(conversation,), kwargs=generation_kwargs, daemon=True)
    thread.start()

    history[-1][1] = ""
    for token in streamer:
        history[-1][1] += token
        yield history

    logger.info(" Assistant: " + repr(history[-1][1]))
    time_taken = time.time() - start_time
    logger.info(f" --- Completed (took {time_taken:.1f}s)")
    return history
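

# chat UI: parameter sidebar on the left, chat window and controls on the right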
class ConceptGuidanceUI:
    def __init__(self):
        model_names = list(MODEL_CONFIGS.keys())
        default_model = model_names[0]
        default_config = MODEL_CONFIGS[default_model]
        default_concepts = default_config["concepts"]
        default_concept = default_config["default_concept"]

        saved_input = gr.State("")

        with gr.Row(elem_id="concept-guidance-container"):
            with gr.Column(scale=1, min_width=256):
                model_dropdown = gr.Dropdown(model_names, value=default_model, label="Model")
                concept_dropdown = gr.Dropdown(default_concepts, value=default_concept, label="Concept")
                guidance_scale = gr.Slider(*default_config["guidance_interval"], value=default_config["default_guidance_scale"], label="Guidance Scale")
                min_guidance_layer = gr.Slider(1.0, 32.0, value=16.0, step=1.0, label="First Guidance Layer")
                max_guidance_layer = gr.Slider(1.0, 32.0, value=32.0, step=1.0, label="Last Guidance Layer")
                temperature = gr.Slider(0.0, 1.0, value=0.0, step=0.01, label="Temperature")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.01, label="Repetition Penalty")
                length_penalty = gr.Slider(0.0, 2.0, value=1.2, step=0.01, label="Length Penalty")

            with gr.Column(scale=3, min_width=512):
                chatbot = gr.Chatbot(scale=1, height=200)
                with gr.Row():
                    self.retry_btn = gr.Button("🔄 Retry", size="sm")
                    self.undo_btn = gr.Button("↩️ Undo", size="sm")
                    self.clear_btn = gr.Button("🗑️ Clear", size="sm")
                with gr.Group():
                    with gr.Row():
                        prompt_field = gr.Textbox(placeholder="Type a message...", show_label=False, label="Message", scale=7, container=False)
                        self.submit_btn = gr.Button("Submit", variant="primary", scale=1, min_width=150)
                        self.stop_btn = gr.Button("Stop", variant="secondary", scale=1, min_width=150, visible=False)
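
        # wire up the chat events: submit and retry both stash the prompt,
        # append it to the history, and then stream a guided completion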
        generation_args = [
            model_dropdown,
            concept_dropdown,
            guidance_scale,
            min_guidance_layer,
            max_guidance_layer,
            temperature,
            repetition_penalty,
            length_penalty,
        ]

        model_dropdown.change(set_defaults, [model_dropdown], [model_dropdown, concept_dropdown, guidance_scale, min_guidance_layer, max_guidance_layer], queue=False)

        submit_triggers = [prompt_field.submit, self.submit_btn.click]
        submit_event = gr.on(
            submit_triggers, self.clear_and_save_input, [prompt_field], [prompt_field, saved_input], queue=False
        ).then(
            add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
        ).then(
            generate_completion,
            [chatbot] + generation_args,
            [chatbot],
            concurrency_limit=1,
        )
        self.setup_stop_events(submit_triggers, submit_event)

        retry_triggers = [self.retry_btn.click]
        retry_event = gr.on(
            retry_triggers, self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
        ).then(
            add_user_prompt, [saved_input, chatbot], [chatbot], queue=False
        ).then(
            generate_completion,
            [chatbot] + generation_args,
            [chatbot],
            concurrency_limit=1,
        )
        self.setup_stop_events(retry_triggers, retry_event)

        self.undo_btn.click(
            self.delete_prev_message, [chatbot], [chatbot, saved_input], queue=False
        ).then(
            lambda x: x, [saved_input], [prompt_field]
        )
        self.clear_btn.click(lambda: [None, None], None, [chatbot, saved_input], queue=False)
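
    # input helpers shared by the submit/retry/undo events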
    def clear_and_save_input(self, message):
        # empty the textbox and stash the message for the next pipeline step
        return "", message

    def delete_prev_message(self, history):
        # guard against Undo/Retry on an empty chat
        if not history:
            return history, ""
        message, _ = history.pop()
        return history, message or ""
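
    # toggle Submit/Stop visibility around a generation and let Stop cancel it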
    def setup_stop_events(self, event_triggers, event_to_cancel):
        if self.submit_btn:
            for event_trigger in event_triggers:
                event_trigger(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    ),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            event_to_cancel.then(
                lambda: (gr.Button(visible=True), gr.Button(visible=False)),
                None,
                [self.submit_btn, self.stop_btn],
                show_api=False,
                queue=False,
            )
            self.stop_btn.click(
                None,
                None,
                None,
                cancels=event_to_cancel,
                show_api=False,
            )
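

# stretch the main container to fill the page (pairs with fill_height=True below)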
css = """ | |
#concept-guidance-container { | |
flex-grow: 1; | |
} | |
""".strip() | |


with gr.Blocks(title="Concept Guidance", fill_height=True, css=css) as demo:
    ConceptGuidanceUI()

demo.queue()
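

# when running locally, pass --share to get a public Gradio link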
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    demo.launch(share=args.share)