Spaces:
Sleeping
Sleeping
import spaces | |
import gradio as gr | |
import torch | |
from peft import PeftModel, PeftConfig | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import json | |
from datetime import datetime | |
from uuid import uuid4 | |
import os | |
from pathlib import Path | |
from huggingface_hub import CommitScheduler | |
from utils import hide_code, hide_css | |
# TODO make it so that feedback is only saved on prev. example if user makes another obfuscation
# and changes slider but doesn't hit obfuscate
# TODO maybe make it save and reset if user hits submit feedback
# TODO sampling params for models
# TODO obfuscation ID?
def convert_data_to_format(text):
    """Wrap ``text`` in the prompt template used by the StyleRemix LoRA adapters.

    The result is ``"### Original: <text>"`` followed by a newline, a single
    space, and the ``"### Rewrite:"`` marker. The template is reproduced
    exactly, including the space after the newline.
    """
    template = "### Original: {}\n ### Rewrite:"
    return template.format(text)
# Adapter name -> Hugging Face Hub repo of the LoRA adapter for that style
# direction. Keys are the adapter names registered on the PEFT model.
MODEL_PATHS = {
    "length_more": "hallisky/lora-length-long-llama-3-8b",
    "length_less": "hallisky/lora-length-short-llama-3-8b",
    "function_more": "hallisky/lora-function-more-llama-3-8b",
    "function_less": "hallisky/lora-function-less-llama-3-8b",
    "grade_more": "hallisky/lora-grade-highschool-llama-3-8b",
    "grade_less": "hallisky/lora-grade-elementary-llama-3-8b",
    "formality_more": "hallisky/lora-formality-formal-llama-3-8b",
    "formality_less": "hallisky/lora-formality-informal-llama-3-8b",
    "sarcasm_more": "hallisky/lora-sarcasm-more-llama-3-8b",
    "sarcasm_less": "hallisky/lora-sarcasm-less-llama-3-8b",
    "voice_passive": "hallisky/lora-voice-passive-llama-3-8b",
    "voice_active": "hallisky/lora-voice-active-llama-3-8b",
    "type_persuasive": "hallisky/lora-type-persuasive-llama-3-8b",
    "type_expository": "hallisky/lora-type-expository-llama-3-8b",
    "type_narrative": "hallisky/lora-type-narrative-llama-3-8b",
    "type_descriptive": "hallisky/lora-type-descriptive-llama-3-8b",
}

# The adapter that is loaded first (dicts preserve insertion order).
FIRST_MODEL = next(iter(MODEL_PATHS))

# Upper bound on the number of tokens generated per rewrite.
MAX_NEW_TOKENS = 1024

# Markdown shown at the top of the page.
# NOTE(review): the two contact addresses below appear as "[email protected]"
# placeholders in this copy of the file and cannot be recovered from it;
# restore the real addresses before deploying.
DESCRIPTION = """\
# Authorship Obfuscation with StyleRemix
This Space demonstrates StyleRemix, a controllable and interpretable method for authorship obfuscation. At its core, it uses a Llama-3 model with 8B parameters and various LoRA adapters fine-tuned to rewrite text towards specific stylistic attributes (like text being longer or shorter). Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
<br> 🕵️ Want to learn more? Check out our paper [here](https://arxiv.org/abs/2408.15666v1) and our code [here](https://github.com/jfisher52/StyleRemix)!
<br> 📧 Have questions about our work or issues with the demo? Feel free to email us at [email protected] and [email protected].
<br> <b>Disclaimer</b>: <em>We may collect and use your queries for further research and development purposes. The data collected will be anonymized and used to enhance our understanding of desired stylistic attributes when rewriting text.</em>
"""
import subprocess | |
def print_nvidia_smi():
    """Print the output of ``nvidia-smi`` for debugging GPU usage.

    Failures are reported on stdout instead of being raised: a non-zero exit
    status and a missing executable each produce their own message.
    """
    command = ["nvidia-smi"]
    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=True)
    except FileNotFoundError:
        # The executable could not be found at all.
        print("nvidia-smi is not installed or not in the PATH.")
    except subprocess.CalledProcessError as e:
        # The command ran but exited with a non-zero status.
        print(f"Failed to run nvidia-smi: {e}")
    else:
        print(completed.stdout)
# Load the base model and every LoRA adapter once, at import time.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation had been stripped - confirm that all model loading is
# meant to live under the CUDA branch.
if torch.cuda.is_available():
    device = "cuda"
    model_id = "meta-llama/Meta-Llama-3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_id, add_bos_token=True, add_eos_token=False, padding_side="left")
    tokenizer.add_special_tokens({'pad_token': '<padding_token>'})
    base_model = AutoModelForCausalLM.from_pretrained(model_id).to(device)  # device_map="auto" requires accelerate
    base_model.resize_token_embeddings(len(tokenizer))  # Resize to add pad token. Value doesn't matter

    # Wrap the base model with the first adapter ...
    model = PeftModel.from_pretrained(base_model, MODEL_PATHS[FIRST_MODEL], adapter_name=FIRST_MODEL).to(device)

    # ... then register the remaining adapters on the same PEFT model.
    for cur_adapter, adapter_repo in MODEL_PATHS.items():
        if cur_adapter != FIRST_MODEL:
            model.load_adapter(adapter_repo, adapter_name=cur_adapter)

    # Loading adapters seems to re-allocate to CPU, so move everything back.
    model.to(device)
    model.eval()
else:
    device = "cpu"
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Identifier stored in every saved record and used in the dataset file name.
# It is created once when this module is imported, so it identifies this
# running process rather than an individual browser session.
user_id = str(uuid4())

# Local folder that collects the JSON-lines records; one file per process.
JSON_DATASET_DIR = Path("json_dataset")
JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
JSON_DATASET_PATH = JSON_DATASET_DIR / f"train-{user_id}.json"

# Background uploader that pushes JSON_DATASET_DIR to the dataset repo under
# "data". `every` is presumably in minutes (i.e. every 30 seconds) - confirm
# against the huggingface_hub CommitScheduler documentation.
scheduler = CommitScheduler(
    repo_id="authorship-obfuscation-demo-data",
    repo_type="dataset",
    folder_path=JSON_DATASET_DIR,
    path_in_repo="data",
    every=0.5
)
def save_data(data):
    """Append ``data`` as one JSON line to ``JSON_DATASET_PATH``.

    The write happens while holding ``scheduler.lock`` so that it does not
    overlap with whatever else takes that lock.
    """
    with scheduler.lock, JSON_DATASET_PATH.open("a") as sink:
        json.dump(data, sink)
        sink.write("\n")
def save_feedback(feedback_rating, feedback_text, latest_obfuscation):
    """Attach the user's feedback to the latest record and save it again.

    ``latest_obfuscation`` is updated in place with the rating and the free
    text, then appended to the dataset file via ``save_data``.

    Returns the values that reset the feedback widgets: the default rating,
    an empty feedback text, and an update that makes the confirmation
    message visible.
    """
    latest_obfuscation.update(
        feedback_rating=feedback_rating,
        feedback_text=feedback_text,
    )
    save_data(latest_obfuscation)

    cleared_rating = "No Feedback Selected"
    cleared_text = ""
    return cleared_rating, cleared_text, gr.update(visible=True)
def _collect_adapter_weights(length, function_words, grade_level, formality,
                             sarcasm, voice, persuasive, descriptive,
                             narrative, expository):
    """Map every non-zero slider to ``{adapter_name: abs(slider_value)}``."""
    # (slider value, adapter for positive values, adapter for negative values)
    bipolar = (
        (length, "length_more", "length_less"),
        (function_words, "function_more", "function_less"),
        (grade_level, "grade_more", "grade_less"),
        (sarcasm, "sarcasm_more", "sarcasm_less"),
        (formality, "formality_more", "formality_less"),
        (voice, "voice_active", "voice_passive"),
    )
    # Writing-type sliders only have one direction.
    writing_types = (
        (persuasive, "type_persuasive"),
        (descriptive, "type_descriptive"),
        (narrative, "type_narrative"),
        (expository, "type_expository"),
    )

    weights = {}
    for value, positive_key, negative_key in bipolar:
        if value != 0:
            weights[positive_key if value > 0 else negative_key] = abs(value)
    for value, key in writing_types:
        if value != 0:
            weights[key] = abs(value)
    return weights


def _combo_adapter_name(adapter_weights):
    """Build the name of the combined adapter, e.g. ``length_more50-voice_active25``."""
    # round(), not int(): 100 * 0.29 is 28.999999999999996 in floating point,
    # so truncation would give 0.28 and 0.29 the same name.
    return "-".join(
        f"{name}{round(100 * weight)}" for name, weight in adapter_weights.items()
    )


def greet(input_text, length, function_words, grade_level, formality, sarcasm, voice, persuasive, descriptive, narrative, expository):
    """Rewrite ``input_text`` with the LoRA adapters selected by the sliders.

    Each non-zero slider selects one adapter, weighted by the slider's
    absolute value. The selected adapters are merged into one weighted
    adapter, which is then used for generation. If every slider is zero the
    input text is returned unchanged and the model is not called.

    Args:
        input_text: Text to obfuscate.
        length, function_words, grade_level, formality, sarcasm, voice:
            Signed sliders; the sign picks the "more"/"less" (for voice,
            "active"/"passive") adapter.
        persuasive, descriptive, narrative, expository: Writing-type sliders.

    Returns:
        A 4-tuple: the rewritten text, two updates that make the feedback
        widgets interactive, and the record that was saved for this request.
    """
    current_time = datetime.now().isoformat()

    adapter_weights = _collect_adapter_weights(
        length, function_words, grade_level, formality, sarcasm, voice,
        persuasive, descriptive, narrative, expository,
    )
    print(adapter_weights)

    if adapter_weights:
        combo_adapter_name = _combo_adapter_name(adapter_weights)
        print(combo_adapter_name)

        # Add and activate the weighted combination of the selected adapters.
        # NOTE(review): combined adapters are never removed, so every new
        # slider combination stays registered on the model - consider
        # cleaning them up if memory becomes a problem.
        model.add_weighted_adapter(
            list(adapter_weights.keys()),
            weights=list(adapter_weights.values()),
            adapter_name=combo_adapter_name,
            combination_type="cat",
        )
        model.set_adapter(combo_adapter_name)

        # Build the prompt and generate the rewrite.
        converted_text = convert_data_to_format(input_text)
        inputs = tokenizer(converted_text, return_tensors="pt", max_length=2048, truncation=True).to(device)
        input_length = inputs.input_ids.shape[1]
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, top_p=0.95)

        # Only the tokens after the prompt are the rewrite itself.
        response = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True).strip()
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    else:
        # No slider was moved: nothing to do.
        response = input_text
        full_output = response

    # Kept local on purpose: the module-level name ``latest_obfuscation`` is
    # the per-session gr.State component and must not be overwritten with one
    # request's data.
    record = {
        "datetime": current_time,
        "user_id": user_id,
        "input_text": input_text,
        "sliders": {
            "length": length,
            "function_words": function_words,
            "grade_level": grade_level,
            "sarcasm": sarcasm,
            "formality": formality,
            "voice": voice,
            "persuasive": persuasive,
            "descriptive": descriptive,
            "narrative": narrative,
            "expository": expository,
        },
        "input": input_text,
        "output": response,
        "full_output": full_output,
        "feedback_rating": "No Feedback Selected",
        "feedback_text": "",
    }
    save_data(record)

    return response, gr.update(interactive=True), gr.update(interactive=True), record
def auto_sliders():
    """Return the ten slider values applied by the "automatic" button.

    The first seven values are 0.5 and the last three are 0; the input text
    is not consulted.
    NOTE(review): with the current wiring the seventh value lands on the
    first writing-type slider - confirm that this is intended.
    """
    half_values = [0.5 for _ in range(7)]
    zero_values = [0 for _ in range(3)]
    return half_values + zero_values
def reset_sliders():
    """Return ten zeros, one per slider, to reset every slider."""
    slider_count = 7 + 3
    return [0] * slider_count
def toggle_slider(checked, value):
    """Enable or disable a slider to follow its checkbox.

    When ``checked`` is truthy the slider becomes interactive and is set to
    ``value``; otherwise it is reset to 0 and made non-interactive.
    """
    if not checked:
        return gr.update(value=0, interactive=False)
    return gr.update(value=value, interactive=True)
def reset_writing_type_sliders(selected_type):
    """Reset the four writing-type sliders and enable only the selected one.

    Every slider is set to 0. The slider matching ``selected_type`` is made
    interactive; with ``"None"`` all four stay disabled. The returned list
    follows the order Persuasive, Descriptive, Narrative, Expository.
    """
    writing_types = ["Persuasive", "Descriptive", "Narrative", "Expository"]
    if selected_type != "None":
        active = writing_types.index(selected_type)
    else:
        active = None
    return [
        gr.update(value=0, interactive=(position == active))
        for position in range(4)
    ]
def update_save_feedback_button(feedback_rating, feedback_text):
    """Enable the submit button only when some feedback has been given.

    Feedback counts as given when a rating other than the default is selected
    or the feedback text is not blank. Returns two updates: the button's
    interactivity and the visibility of the "please provide feedback" warning.
    """
    has_rating = feedback_rating != "No Feedback Selected"
    has_feedback = has_rating or feedback_text.strip() != ""
    return gr.update(interactive=has_feedback), gr.update(visible=not has_feedback)
def update_obfuscate_button(input_text):
    """Enable the obfuscate button only when the input text is not blank.

    Returns two updates: the button's interactivity and the visibility of the
    "please enter text" warning, which is shown exactly when the input is
    empty or whitespace only.
    """
    is_blank = input_text.strip() == ""
    return gr.update(interactive=not is_blank), gr.update(visible=is_blank)
def check_initial_feedback_state(feedback_rating, feedback_text):
    """Compute the initial submit-button and warning state on page load.

    Delegates to ``update_save_feedback_button`` with the same arguments.
    """
    initial_state = update_save_feedback_button(
        feedback_rating=feedback_rating,
        feedback_text=feedback_text,
    )
    return initial_state
# Build the page: the left column holds the input text and the style sliders,
# the right column holds the output and the optional feedback form.
demo = gr.Blocks()
with demo:
    # Per-session holder for the most recent record returned by `greet`.
    latest_obfuscation = gr.State({})

    gr.Markdown(DESCRIPTION)
    gr.HTML(hide_css)

    with gr.Row():
        with gr.Column(variant="panel"):
            gr.Markdown("# 1) Input Text\n### Enter the text to be obfuscated. We recommend *full sentences* or *paragraphs*.")
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="The quick brown fox jumped over the lazy dogs."
            )

            gr.Markdown("# 2) Style Element Sliders\n### Adjust the style element sliders to the desired levels to steer the obfuscation.")

            with gr.Row():
                auto_button = gr.Button("Choose slider values automatically (based on input text)")
                reset_button = gr.Button("Reset slider values")

            # `sliders` must end up in the order of `greet`'s slider
            # parameters: six signed sliders, then four writing-type sliders.
            sliders = []
            # (label, minimum, maximum, default)
            slider_values = [
                ("Length (Shorter \u2192 Longer)", -1, 1, 0),
                ("Function Words (Fewer \u2192 More)", -1, 1, 0),
                ("Grade Level (Lower \u2192 Higher)", -1, 1, 0),
                ("Formality (Less \u2192 More)", -1, 1, 0),
                ("Sarcasm (Less \u2192 More)", -1, 1, 0),
                ("Voice (Passive \u2192 Active)", -1, 1, 0),
                ("Writing Type: Persuasive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Descriptive (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Narrative (None \u2192 More)", 0, 1, 0),
                ("Writing Type: Expository (None \u2192 More)", 0, 1, 0)
            ]

            non_writing_type_sliders = []
            writing_type_sliders = []

            # Signed sliders: each one is unlocked by its own checkbox.
            for label, min_val, max_val, default in slider_values:
                if "Writing Type" not in label:
                    with gr.Row():
                        # The text before "(" names the checkbox; the text
                        # inside the parentheses labels the slider.
                        checkbox = gr.Checkbox(label=label.split("(")[0], scale=1)
                        slider = gr.Slider(label=label.split("(")[1][:-1], minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False, scale=3)
                        checkbox.change(fn=toggle_slider, inputs=[checkbox, gr.State(default)], outputs=slider)
                        non_writing_type_sliders.append(slider)
                        sliders.append(slider)

            writing_type_radio = gr.Radio(
                label="Writing Type",
                choices=["None", "Persuasive", "Descriptive", "Narrative", "Expository"],
                value="None"
            )

            # Writing-type sliders: only the one picked in the radio is unlocked.
            for label, min_val, max_val, default in slider_values:
                if "Writing Type" in label:
                    with gr.Row():
                        slider = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=0.01, value=default, interactive=False)
                        writing_type_sliders.append(slider)
                        sliders.append(slider)

            # Wired only after the four sliders exist, so the listener is
            # registered with a fully populated output list.
            writing_type_radio.change(fn=reset_writing_type_sliders, inputs=writing_type_radio, outputs=writing_type_sliders)

            obfuscate_button = gr.Button("Obfuscate Text", interactive=False)
            warning_message = gr.Markdown(
                "<div style='text-align: center; color: red;'>⚠️ Please enter text before obfuscating. ⚠️</div>", visible=True
            )

            auto_button.click(fn=auto_sliders, inputs=[], outputs=sliders)
            reset_button.click(fn=reset_sliders, inputs=[], outputs=sliders)
            input_text.change(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

            # Initialize the button and warning message state on page load
            demo.load(fn=update_obfuscate_button, inputs=input_text, outputs=[obfuscate_button, warning_message])

        with gr.Column(variant="panel"):
            gr.Markdown("# 3) Obfuscated Output")
            output = gr.Textbox(label="Output", lines=3)

            gr.Markdown("## Feedback [Optional]")

            # Thumbs up / thumbs down. These labels are also the values stored
            # in the "feedback_rating" field of saved records.
            gr.Markdown("### Is the response good or bad?")
            feedback_rating = gr.Radio(choices=["No Feedback Selected", "Good 👍", "Bad 👎"], value="No Feedback Selected", interactive=False, label="Rate the Response")

            # Free-text feedback box
            gr.Markdown("### Provide any feedback on the obfuscation")
            feedback_text = gr.Textbox(label="Feedback", lines=3, interactive=False)

            # Running an obfuscation fills the output, unlocks the feedback
            # widgets and stores the new record in the session state.
            obfuscate_button.click(
                fn=greet,
                inputs=[input_text] + sliders,
                outputs=[output, feedback_rating, feedback_text, latest_obfuscation])

            save_feedback_button = gr.Button("Submit Feedback", interactive=False)

            # NOTE(review): the last character of this message was an emoji
            # that could not be recovered from this copy of the file; restore
            # it from the original source.
            confirmation_message = gr.Markdown(
                "<div id='confirmation-message' style='text-align: center; color: green;'>🥳 Feedback has been submitted successfully! π</div>", visible=False
            )
            feedback_warning_message = gr.Markdown(
                "<div id='feedback-warning' style='text-align: center; color: red;'>⚠️ Please provide feedback or a rating before submitting. ⚠️</div>", visible=True
            )

            # Update the interactivity of the save_feedback_button based on feedback_rating and feedback_text
            feedback_rating.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])
            feedback_text.change(fn=update_save_feedback_button, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

            save_feedback_button.click(
                fn=save_feedback,
                inputs=[feedback_rating, feedback_text, latest_obfuscation],
                outputs=[feedback_rating, feedback_text, confirmation_message]
            )
            # Second listener on the same click: runs the `hide_code`
            # JavaScript snippet imported from utils.
            save_feedback_button.click(
                fn=None,
                inputs=[],
                outputs=None,
                js=hide_code
            )

            # Initialize the save feedback button and warning message state on page load
            demo.load(fn=check_initial_feedback_state, inputs=[feedback_rating, feedback_text], outputs=[save_feedback_button, feedback_warning_message])

demo.launch()