AjayP13 committed
Commit ae36130 · verified · 1 parent: 3545246

Update app.py

Files changed (1):
  1. app.py  +148 -124
app.py CHANGED
@@ -1,153 +1,177 @@
- import itertools
  import torch
- from statistics import mean
- import numpy as np
- from torch.nn.utils.rnn import pad_sequence
  import gradio as gr
- from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
- from sentence_transformers import SentenceTransformer
- from mutual_implication_score import MIS
- from time import time
-
- # Load the model and tokenizer
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- model_name = "google/flan-t5-large"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
- model.to(device)
- embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu').half()
- luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
- luar_model.to(device)
- luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
- mis_model = MIS(device=device)
-
- def get_target_style_embeddings(target_texts_batch):
-     all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
-     embeddings = embedding_model.encode(all_target_texts, batch_size=len(all_target_texts), convert_to_tensor=True, show_progress_bar=False)
-     lengths = [len(target_texts) for target_texts in target_texts_batch]
-     split_embeddings = torch.split(embeddings, lengths)
-     padded_embeddings = pad_sequence(split_embeddings, batch_first=True, padding_value=0.0)
-     mask = (torch.arange(padded_embeddings.size(1))[None, :] < torch.tensor(lengths)[:, None]).to(embeddings.dtype).unsqueeze(-1)
-     mean_embeddings = torch.sum(padded_embeddings * mask, dim=1) / mask.sum(dim=1)
-     return mean_embeddings.float().cpu().numpy()
-
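Note: the removed get_target_style_embeddings pools a ragged batch of per-text embeddings by padding and masking. A minimal, self-contained sketch of that same pooling step, with dummy tensors standing in for the SentenceTransformer output (names here are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two groups of per-text embeddings with different group sizes (embedding dim 4)
groups = [torch.ones(3, 4), 2 * torch.ones(1, 4)]
lengths = [g.size(0) for g in groups]

padded = pad_sequence(groups, batch_first=True, padding_value=0.0)  # (2, 3, 4)
# Mask out the padded rows so they do not contribute to the mean
mask = (torch.arange(padded.size(1))[None, :] < torch.tensor(lengths)[:, None]).float().unsqueeze(-1)
means = (padded * mask).sum(dim=1) / mask.sum(dim=1)  # (2, 4): one mean embedding per group
assert torch.allclose(means[0], torch.ones(4)) and torch.allclose(means[1], 2 * torch.ones(4))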
- @torch.no_grad()
- def get_luar_embeddings(texts_batch):
-     assert len(set([len(texts) for texts in texts_batch])) == 1
-     episodes = texts_batch
-     tokenized_episodes = [luar_tokenizer(episode, max_length=512, padding="longest", truncation=True, return_tensors="pt").to(device) for episode in episodes]
-     episode_lengths = [t["attention_mask"].shape[0] for t in tokenized_episodes]
-     max_episode_length = max(episode_lengths)
-     sequence_lengths = [t["attention_mask"].shape[1] for t in tokenized_episodes]
-     max_sequence_length = max(sequence_lengths)
-     padded_input_ids = [torch.nn.functional.pad(t["input_ids"], (0, 0, 0, max_episode_length - t["input_ids"].shape[0])) for t in tokenized_episodes]
-     padded_attention_mask = [torch.nn.functional.pad(t["attention_mask"], (0, 0, 0, max_episode_length - t["attention_mask"].shape[0])) for t in tokenized_episodes]
-     input_ids = torch.stack(padded_input_ids)
-     attention_mask = torch.stack(padded_attention_mask)
-     return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
-
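Note: the removed helper computes max_sequence_length but never uses it, so torch.stack only succeeds when every episode happens to tokenize to the same sequence length (the demo's batch-of-one calls never hit this). A sketch of padding both ragged dimensions before stacking, under that reading of the code's intent:

import torch
import torch.nn.functional as F

# Two tokenized "episodes" of shape (num_texts, seq_len), ragged in both dimensions
episodes = [torch.ones(3, 7, dtype=torch.long), torch.ones(2, 5, dtype=torch.long)]
max_ep = max(e.shape[0] for e in episodes)
max_seq = max(e.shape[1] for e in episodes)

# F.pad takes (last-dim-left, last-dim-right, first-dim-top, first-dim-bottom)
padded = [F.pad(e, (0, max_seq - e.shape[1], 0, max_ep - e.shape[0])) for e in episodes]
batch = torch.stack(padded)  # (2, 3, 7): the (batch, episode, tokens) layout fed to luar_model
assert batch.shape == (2, 3, 7)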
- def compute_mis(texts, target_texts_batch):
-     a_texts = list(itertools.chain.from_iterable([[t] * len(target_texts) for t, target_texts in zip(texts, target_texts_batch)]))
-     b_texts = list(itertools.chain.from_iterable(target_texts_batch))
-     scores = mis_model.compute(a_texts, b_texts, batch_size=len(a_texts))
-     for idx, (score, a_text, b_text) in enumerate(zip(scores, a_texts, b_texts)):
-         if a_text == b_text:
-             scores[idx] = 1.0
-     final_scores = []
-     current_idx = 0
-     for target_texts in target_texts_batch:
-         final_scores.append(mean(scores[idx:idx+len(target_texts)]))
-     return final_scores
-
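Note: the grouping loop above slices scores with the stale idx left over from the enumerate loop and never advances current_idx, so every group averages the same trailing window. A corrected, standalone sketch of the per-group mean it appears to intend:

from statistics import mean

scores = [0.2, 0.4, 0.9, 0.1]                  # flattened pairwise MIS scores
target_texts_batch = [["a", "b"], ["c", "d"]]  # two groups of two target texts

final_scores = []
current_idx = 0
for target_texts in target_texts_batch:
    final_scores.append(mean(scores[current_idx:current_idx + len(target_texts)]))
    current_idx += len(target_texts)  # advance the running offset
assert final_scores == [mean([0.2, 0.4]), mean([0.9, 0.1])]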
- def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
-     inputs = tokenizer(source_texts, return_tensors="pt").to(device)
-     target_style_embeddings = get_target_style_embeddings(target_texts_batch)
-     source_style_luar_embeddings = get_luar_embeddings([[st] for st in source_texts])
-     print("Log 0", time(), source_style_luar_embeddings.shape)
-     target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
-     print("Log 1", time(), target_style_luar_embeddings.shape)
-     baseline_sim = compute_mis(source_texts, target_texts_batch)
-     print("Log 1.5", time(), len(baseline_sim))
-
-
-     # Generate the output with specified temperature and top_p
-     output = model.generate(
-         inputs["input_ids"],
-         do_sample=True,
-         temperature=temperature,
-         top_p=top_p,
-         max_length=1024,
-         num_return_sequences=reranking,
      )
-     print("Log 2", time(), output.shape)
-     generated_texts = tokenizer.batch_decode(output, skip_special_tokens=True)
-     generated_texts = [generated_texts[i * reranking:(i + 1) * reranking] for i in range(inputs["input_ids"].shape[0])]  # Unflatten
-
-     # Evaluate candidates
-     candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
-     candidates_sim = [compute_mis([candidates[i] for candidates in generated_texts], target_texts_batch) for i in range(reranking)]
-     print("Log 3", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))
-
-     # Get best based on re-ranking
-     generated_texts = [texts[0] for texts in generated_texts]
-     print("Final Log", time(), len(generated_texts))
-
-     return generated_texts
-
  def run_tinystyler(source_text, target_texts, reranking, temperature, top_p):
-     target_texts = [target_text.strip() for target_text in target_texts.split("\n")]
-     return run_tinystyler_batch([source_text], [target_texts], reranking, temperature, top_p)[0]

  # Preset examples with cached generations
  preset_examples = {
-     "Example 1": {
-         "source_text": "Once upon a time in a small village",
-         "target_texts": "In a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness.",
-         "reranking": 5,
          "temperature": 1.0,
          "top_p": 1.0,
-         "output": "Once upon a time in a small village in a land far away, there was a kingdom ruled by a wise king. Every day, the people of the kingdom would gather to listen to the king's stories, which were full of wisdom and kindness."
      },
-     "Example 2": {
-         "source_text": "The quick brown fox",
-         "target_texts": "A nimble, chocolate-colored fox swiftly darted through the emerald forest, weaving between trees with grace and agility.",
-         "reranking": 5,
-         "temperature": 0.9,
-         "top_p": 0.9,
-         "output": "The quick brown fox, a nimble, chocolate-colored fox, swiftly darted through the emerald forest, weaving between trees with grace and agility."
-     }
  }

- # Define Gradio interface
  with gr.Blocks(theme="ParityError/[email protected]") as demo:
      gr.Markdown("# TinyStyler Demo")
-     gr.Markdown("Style transfer the source text into the target style, given some example texts of the target style. You can adjust re-ranking and top_p to your desire to control the quality of style transfer. A higher re-ranking value will generally result in better generations, at slower speed.")
-
      with gr.Row():
-         example_dropdown = gr.Dropdown(label="Examples", choices=list(preset_examples.keys()))
-
-     source_text = gr.Textbox(lines=3, placeholder="Enter the source text to transform into the target style...", label="Source Text")
-     target_texts = gr.Textbox(lines=5, placeholder="Enter example texts of the target style (one per line)...", label="Example Texts of the Target Style")
-     reranking = gr.Slider(1, 10, value=5, step=1, label="Re-ranking")
      temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
      top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.1, label="Top-P")
-
-     output = gr.Textbox(lines=5, placeholder="Click 'Generate' to transform the source text into the target style.", label="Output", interactive=False)

      def set_example(example_name):
          example = preset_examples[example_name]
-         return example["source_text"], example["target_texts"], example["reranking"], example["temperature"], example["top_p"], example["output"]

      example_dropdown.change(
          set_example,
          inputs=[example_dropdown],
-         outputs=[source_text, target_texts, reranking, temperature, top_p, output]
      )
-
      btn = gr.Button("Generate")
-     btn.click(run_tinystyler, [source_text, target_texts, reranking, temperature, top_p], output)

      # Initialize the fields with the first example
-     example_dropdown.value, (source_text.value, target_texts.value, reranking.value, temperature.value, top_p.value, output.value) = list(preset_examples.keys())[0], set_example(list(preset_examples.keys())[0])

- demo.launch()
  import torch
  import gradio as gr
+ from huggingface_hub import hf_hub_download
+ import importlib.util
+ from functools import lru_cache
+
+ # Import TinyStyler
+ tinystyler_module = importlib.util.module_from_spec(
+     importlib.util.spec_from_file_location(
+         "tinystyler",
+         hf_hub_download(repo_id="tinystyler/tinystyler", filename="tinystyler.py"),
      )
+ )
+ tinystyler_module.__spec__.loader.exec_module(tinystyler_module)
+ (
+     get_tinystyle_model,
+     get_style_embedding_model,
+     get_luar_model,
+     get_mis_model,
+     run_tinystyler_batch,
+ ) = (
+     tinystyler_module.get_tinystyle_model,
+     tinystyler_module.get_style_embedding_model,
+     tinystyler_module.get_luar_model,
+     tinystyler_module.get_mis_model,
+     tinystyler_module.run_tinystyler_batch,
+ )
+
+
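Note: the tuple unpacking above simply pulls the public helpers out of the dynamically loaded module; if tinystyler.py sat on sys.path instead of being fetched from the Hub, the equivalent would be a plain import (illustrative only):

# Equivalent if tinystyler.py were a local module next to app.py:
from tinystyler import (
    get_tinystyle_model,
    get_style_embedding_model,
    get_luar_model,
    get_mis_model,
    run_tinystyler_batch,
)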
+ @lru_cache(maxsize=256)
  def run_tinystyler(source_text, target_texts, reranking, temperature, top_p):
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         target_texts = [target_text.strip() for target_text in target_texts.split("\n")]
+         assert (
+             len(source_text) <= 200
+         ), "Please enter a shorter source text (max 200 chars) for the purposes of this demo."
+         assert (
+             len(target_texts) <= 16
+         ), "Please enter 16 or fewer examples for the purposes of this demo."
+         for target_text in target_texts:
+             assert (
+                 len(target_text) <= 200
+             ), "Please enter shorter target texts (max 200 chars per line) for the purposes of this demo."
+         return run_tinystyler_batch(
+             [source_text],
+             [target_texts],
+             reranking,
+             temperature,
+             top_p,
+             200,
+             device=device,
+             seed=42,
+         )[0]
+     except Exception as e:
+         return f"Error: {e}"
+
+
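Note: the new entry point can be exercised directly, outside Gradio. A minimal usage sketch (argument values borrowed from the presets below; the second argument is the newline-separated target-style examples):

styled = run_tinystyler(
    "i heard that new pizza joint is lit af",  # source_text
    "Good afternoon, everybody.\nLet me start out by saying that I was sorely tempted to wear a tan suit today.",  # target_texts
    3,    # reranking
    1.0,  # temperature
    1.0,  # top_p
)
print(styled)  # styled text, or an "Error: ..." string if validation fails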
+ #########################################################################
+ # Define Gradio Demo Interface
+ #########################################################################

  # Preset examples with cached generations
  preset_examples = {
+     "Robert De Niro in Taxi Driver's Style": {
+         "source_text": "I know that you and Frank were planning to disconnect me. And I'm afraid that's something I cannot allow to happen.",
+         "target_texts": "You talkin' to me? You talkin' to me? You talkin' to me?\nThen who the hell else are you talking... you talking to me? Well I'm the only one here.\nWho the fuck do you think you're talking to? Oh yeah? OK.",
+         "reranking": 3,
+         "temperature": 1.0,
+         "top_p": 1.0,
+         "output": "Yeah the fuck? I know you and Frank were planning to disconnect me.",
+     },
+     "Informal Style": {
+         "source_text": "Innovation is where bold ideas meet the relentless pursuit of progress.",
+         "target_texts": "the real world, the newly weds and laguna beach\nContact Warner Bros.or just go to ebay.I dont think youll find any\nthat I'm a woman's man with no time to talk!\nWhen you have an eye problem so you see 3,not 2 ( :\ncant wait for a new album from him.\nI'll pick one of my favorite country ones...\nto me, jamie foxx aint all that sexy.\nidk.....but i have faith in you lol\nWang Chung - Everybody Have Fun Tonight\ni am gonna have to defend the werewolf here.\nYEAH, AND I WASN'T VERY COMFORTABLE WITH IT EITHER...\nIF YOU TEXT YOUR ANSWER IN IT MIGHT IF YOU DON'T HAVE TEXT MESSAGES IN YOUR PLAN\nhe is about 83 yrs old\nHE IS TO ME FOR NOW, OUR BLACK GEORGE CLOONEY.\nTill they run out of ideas\neminem because his some of his music is just so funny and relevent to todays pop music enviorment.",
+         "reranking": 3,
+         "temperature": 1.0,
+         "top_p": 1.0,
+         "output": "innovation, where bold ideas meet the relentless pursuit of progress...lol",
+     },
+     "Barack Obama's Style": {
+         "source_text": "i heard that new pizza joint is lit af",
+         "target_texts": "Good afternoon, everybody.\nLet me start out by saying that I was sorely tempted to wear a tan suit today -- (laughter) -- for my last press conference.\nBut Michelle, whose fashion sense is a little better than mine, tells me that's not appropriate in January.\nI covered a lot of the ground that I would want to cover in my farewell address last week.\nSo I'm just going to say a couple of quick things before I start taking questions.\nFirst, we have been in touch with the Bush family today, after hearing about President George H.W. Bush and Barbara Bush being admitted to the hospital this morning.\nThey have not only dedicated their lives to this country, they have been a constant source of friendship and support and good counsel for Michelle and me over the years.\nThey are as fine a couple as we know. And so we want to send our prayers and our love to them. Really good people.\nSecond thing I want to do is to thank all of you.\nSome of you have been covering me for a long time -- folks like Christi and Win.\nSome of you I've just gotten to know. We have traveled the world together.\nWe’ve hit a few singles, a few doubles together.\nI’ve offered advice that I thought was pretty sound, like “don’t do stupid…stuff.” (Laughter.)\nAnd even when you complained about my long answers, I just want you to know that the only reason they were long was because you asked six-part questions. (Laughter.)\nBut I have enjoyed working with all of you.\nThat does not, of course, mean that I’ve enjoyed every story that you have filed.",
+         "reranking": 3,
+         "temperature": 1.0,
+         "top_p": 1.0,
+         "output": "Well, according to my friends I heard that the new pizza joint is in full swing.",
+     },
+     "Donald Trump's Style": {
+         "source_text": "I hereby request your formal approval.",
+         "target_texts": "great American Patriots who voted for me, AMERICA FIRST, and MAKE AMERICA GREAT AGAIN, will have a GIANT VOICE long into the future.\nThey will not be disrespected or treated unfairly in any way, shape or form!!!\nTHE REPUBLICAN PARTY AND, MORE IMPORTANTLY, OUR COUNTRY, NEEDS THE PRESIDENCY MORE THAN EVER BEFORE - THE POWER OF THE VETO.\nSTAY STRONG!\nGet smart Republicans.\nFIGHT!\nGeorgia, we have a job to do TODAY.\nWe have to STOP socialism.\nWe have to PROTECT the American Dream.\nHow do you certify numbers that have now proven to be wrong and, in many cases, fraudulent!\nSad to watch!\nSleepy Eyes Chuck Todd is so happy with the fake voter tabulation process that he can’t even get the words out straight.\nThey found out they voted on a FRAUD.\nThe 75,000,000 great American Patriots who voted for me, AMERICA FIRST, and MAKE AMERICA GREAT AGAIN, will have a GIANT VOICE long into the future.\nThey will not be disrespected or treated unfairly in any way, shape or form!!!\nUSA demands the truth!",
+         "reranking": 3,
          "temperature": 1.0,
          "top_p": 1.0,
+         "output": "NOW I need your formal approval!",
      },
  }

+
  with gr.Blocks(theme="ParityError/[email protected]") as demo:
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Immediately load models
+     get_tinystyle_model(device)
+     get_style_embedding_model(device)
+     get_luar_model(device)
+     get_mis_model(device)
+
      gr.Markdown("# TinyStyler Demo")
+     gr.Markdown(
+         "Style transfer the source text into the target style, given some example texts of the target style. You can adjust re-ranking and top_p to control the quality of the style transfer. A higher re-ranking value will generally result in better generations, at slower speed.\n\n*Please note: this demo runs on a CPU-only machine; generation is much faster when run locally with a GPU.*"
+     )
+
      with gr.Row():
+         example_dropdown = gr.Dropdown(
+             label="Examples", choices=list(preset_examples.keys())
+         )
+
+     source_text = gr.Textbox(
+         lines=3,
+         placeholder="Enter the source text to transform into the target style...",
+         label="Source Text",
+     )
+     target_texts = gr.Textbox(
+         lines=5,
+         placeholder="Enter example texts of the target style (one per line)...",
+         label="Example Texts of the Target Style",
+     )
+     reranking = gr.Slider(1, 5, value=3, step=1, label="Re-ranking")
      temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
      top_p = gr.Slider(0.0, 1.0, value=1.0, step=0.1, label="Top-P")
+
+     output = gr.Textbox(
+         lines=5,
+         placeholder="Click 'Generate' to transform the source text into the target style.",
+         label="Output",
+         interactive=False,
+     )

      def set_example(example_name):
          example = preset_examples[example_name]
+         return (
+             example["source_text"],
+             example["target_texts"],
+             example["reranking"],
+             example["temperature"],
+             example["top_p"],
+             example["output"],
+         )

      example_dropdown.change(
          set_example,
          inputs=[example_dropdown],
+         outputs=[source_text, target_texts, reranking, temperature, top_p, output],
      )
+
      btn = gr.Button("Generate")
+     btn.click(
+         run_tinystyler,
+         [source_text, target_texts, reranking, temperature, top_p],
+         output,
+     )

      # Initialize the fields with the first example
+     (
+         example_dropdown.value,
+         (
+             source_text.value,
+             target_texts.value,
+             reranking.value,
+             temperature.value,
+             top_p.value,
+             output.value,
+         ),
+     ) = list(preset_examples.keys())[0], set_example(list(preset_examples.keys())[0])

+ demo.launch()