AjayP13 commited on
Commit
fbfca4a
·
verified ·
1 Parent(s): 37aa083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -1,9 +1,12 @@
 
1
  import torch
 
2
  import numpy as np
3
  from torch.nn.utils.rnn import pad_sequence
4
  import gradio as gr
5
  from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
6
  from sentence_transformers import SentenceTransformer
 
7
  from time import time
8
 
9
  # Load the model and tokenizer
@@ -16,7 +19,7 @@ embedding_model = SentenceTransformer('AnnaWegmann/Style-Embedding', device='cpu
16
  luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
17
  luar_model.to(device)
18
  luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
19
-
20
 
21
  def get_target_style_embeddings(target_texts_batch):
22
  all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
@@ -43,6 +46,19 @@ def get_luar_embeddings(texts_batch):
43
  attention_mask = torch.stack(padded_attention_mask)
44
  return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
47
  inputs = tokenizer(source_texts, return_tensors="pt").to(device)
48
  target_style_embeddings = get_target_style_embeddings(target_texts_batch)
@@ -50,6 +66,8 @@ def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperatur
50
  print("Log 0", time(), source_style_luar_embeddings.shape)
51
  target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
52
  print("Log 1", time(), target_style_luar_embeddings.shape)
 
 
53
 
54
 
55
  # Generate the output with specified temperature and top_p
@@ -67,6 +85,7 @@ def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperatur
67
 
68
  # Evaluate candidates
69
  candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
 
70
  print("Log 3", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))
71
 
72
  # Get best based on re-ranking
 
1
+ import itertools
2
  import torch
3
+ from statistics import mean
4
  import numpy as np
5
  from torch.nn.utils.rnn import pad_sequence
6
  import gradio as gr
7
  from transformers import AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer
8
  from sentence_transformers import SentenceTransformer
9
+ from mutual_implication_score import MIS
10
  from time import time
11
 
12
  # Load the model and tokenizer
 
19
  luar_model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True).half()
20
  luar_model.to(device)
21
  luar_tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", revision="51b0d9ecec5336314e02f191dd8ca4acc0652fe1", trust_remote_code=True)
22
+ mis_model = MIS(device=device)
23
 
24
  def get_target_style_embeddings(target_texts_batch):
25
  all_target_texts = [target_text for target_texts in target_texts_batch for target_text in target_texts]
 
46
  attention_mask = torch.stack(padded_attention_mask)
47
  return luar_model(input_ids=input_ids, attention_mask=attention_mask).float().cpu().numpy()
48
 
49
+ def compute_mis(texts, target_texts_batch):
50
+ a_texts = list(itertools.chain.from_iterable([[st] * len(target_texts) for st, target_texts in zip(source_texts, target_texts_batch)]))
51
+ b_texts = list(itertools.chain.from_iterable(target_texts_batch))
52
+ scores = mis.compute(a_texts, b_texts, batch_size=len(a_texts))
53
+ for idx, (score, a_text, b_text) in enumerate(zip(scores, a_texts, b_texts)):
54
+ if a_text == b_text:
55
+ scores[idx] = 1.0
56
+ final_scores = []
57
+ current_idx = 0
58
+ for target_texts in target_texts_batch:
59
+ final_scores.append(mean(scores[idx:idx+len(target_texts)]))
60
+ return final_scores
61
+
62
  def run_tinystyler_batch(source_texts, target_texts_batch, reranking, temperature, top_p):
63
  inputs = tokenizer(source_texts, return_tensors="pt").to(device)
64
  target_style_embeddings = get_target_style_embeddings(target_texts_batch)
 
66
  print("Log 0", time(), source_style_luar_embeddings.shape)
67
  target_style_luar_embeddings = get_luar_embeddings(target_texts_batch)
68
  print("Log 1", time(), target_style_luar_embeddings.shape)
69
+ baseline_sim = compute_mis(source_texts, target_texts_batch)
70
+ print("Log 1.5", time(), len(baseline_sim))
71
 
72
 
73
  # Generate the output with specified temperature and top_p
 
85
 
86
  # Evaluate candidates
87
  candidates_luar_embeddings = [get_luar_embeddings([[candidates[i]] for candidates in generated_texts]) for i in range(reranking)]
88
+ candidates_sim = [compute_mis([candidates[i] for candidates in generated_texts], target_texts_batch) for i in range(reranking)]
89
  print("Log 3", time(), len(candidates_luar_embeddings), len(candidates_luar_embeddings[0]))
90
 
91
  # Get best based on re-ranking