Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,12 @@
|
|
|
|
1 |
import torch
|
|
|
2 |
import numpy as np
|
3 |
from torch.nn.utils.rnn import pad_sequence
|
4 |
import gradio as gr
|
5 |
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
|
6 |
from sentence_transformers import SentenceTransformer
|
|
|
7 |
from time import time
|
8 |
|
9 |
# Load the model and tokenizer
|
@@ -16,7 +19,7 @@ embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu
|
|
16 |
luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
|
17 |
luar_model.to(device)
|
18 |
luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
|
19 |
-
|
20 |
|
21 |
def get_target_style_embeddings(target_texts_batch):
|
22 |
all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
|
@@ -43,6 +46,19 @@ def get_luar_embeddings(texts_batch):
|
|
43 |
attention_mask = torch.stack(padded_attention_mask)
|
44 |
return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
|
47 |
inputs = tokenizer(source_texts, return_tensors="pt").to(device)
|
48 |
target_style_embeddings = get_target_style_embeddings(target_texts_batch)
|
@@ -50,6 +66,8 @@ def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperatur
|
|
50 |
print("Log 0", time(), source_style_luar_embeddings.shape)
|
51 |
target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
|
52 |
print("Log 1", time(), target_style_luar_embeddings.shape)
|
|
|
|
|
53 |
|
54 |
|
55 |
# Generate the output with specified temperature and top_p
|
@@ -67,6 +85,7 @@ def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperatur
|
|
67 |
|
68 |
# Evaluate candidates
|
69 |
candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
|
|
|
70 |
print("Log 3", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))
|
71 |
|
72 |
# Get best based on re-ranking
|
|
|
1 |
+
import itertools
|
2 |
import torch
|
3 |
+
from statistics import mean
|
4 |
import numpy as np
|
5 |
from torch.nn.utils.rnn import pad_sequence
|
6 |
import gradio as gr
|
7 |
from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
|
8 |
from sentence_transformers import SentenceTransformer
|
9 |
+
from mutual_implication_score import MIS
|
10 |
from time import time
|
11 |
|
12 |
# Load the model and tokenizer
|
|
|
19 |
luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
|
20 |
luar_model.to(device)
|
21 |
luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
|
22 |
+
mis_model = MIS(device=device)
|
23 |
|
24 |
def get_target_style_embeddings(target_texts_batch):
|
25 |
all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
|
|
|
46 |
attention_mask = torch.stack(padded_attention_mask)
|
47 |
return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
|
48 |
|
49 |
+
def compute_mis(texts, target_texts_batch):
    """Score each text against its group of target texts with the MIS model.

    ``texts[i]`` is paired with every entry of ``target_texts_batch[i]``. All
    pairs are scored in a single ``mis_model.compute`` call and the resulting
    scores are averaged per group.

    Args:
        texts: One text per group.
        target_texts_batch: One list of target texts per group, aligned
            positionally with ``texts``.

    Returns:
        A list with one mean score per entry of ``target_texts_batch``.

    Raises:
        statistics.StatisticsError: If a group of target texts is empty.
    """
    # Repeat every text once per target so both lists line up pair by pair.
    a_texts = list(itertools.chain.from_iterable(
        [text] * len(target_texts)
        for text, target_texts in zip(texts, target_texts_batch)
    ))
    b_texts = list(itertools.chain.from_iterable(target_texts_batch))

    # Skip the model call when there is nothing to score: the batch size is
    # derived from the number of pairs and would otherwise be zero.
    scores = mis_model.compute(a_texts, b_texts, batch_size=len(a_texts)) if a_texts else []

    # An identical pair is pinned to 1.0 regardless of the model's score.
    for idx, (a_text, b_text) in enumerate(zip(a_texts, b_texts)):
        if a_text == b_text:
            scores[idx] = 1.0

    # Walk the flat score list group by group, advancing the window start so
    # each group averages only its own scores.
    final_scores = []
    current_idx = 0
    for target_texts in target_texts_batch:
        end_idx = current_idx + len(target_texts)
        final_scores.append(mean(scores[current_idx:end_idx]))
        current_idx = end_idx
    return final_scores
|
61 |
+
|
62 |
def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
|
63 |
inputs = tokenizer(source_texts, return_tensors="pt").to(device)
|
64 |
target_style_embeddings = get_target_style_embeddings(target_texts_batch)
|
|
|
66 |
print("Log 0", time(), source_style_luar_embeddings.shape)
|
67 |
target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
|
68 |
print("Log 1", time(), target_style_luar_embeddings.shape)
|
69 |
+
baseline_sim = compute_mis(source_texts, target_texts_batch)
|
70 |
+
print("Log 1.5", time(), len(baseline_sim))
|
71 |
|
72 |
|
73 |
# Generate the output with specified temperature and top_p
|
|
|
85 |
|
86 |
# Evaluate candidates
|
87 |
candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
|
88 |
+
candidates_sim = [compute_mis([candidates[i] for candidates in generated_texts], target_texts_batch) for i in range(reranking)]
|
89 |
print("Log 3", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))
|
90 |
|
91 |
# Get best based on re-ranking
|