import math
from typing import Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from omegaconf import DictConfig
from pytorch_utils.modules import MLP
from torch import Tensor

LOG2 = math.log(2)


class BaseMemory(nn.Module):
    """Base clustering module."""

    def __init__(self, config: DictConfig, span_emb_size: int, drop_module: nn.Module):
        super().__init__()

        self.config = config
        self.mem_size = span_emb_size
        self.drop_module = drop_module

        if self.config.sim_func == "endpoint":
            num_embs = 2  # Span start, span end
        else:
            num_embs = 3  # Span start, span end, Hadamard product of the two

        self.mem_coref_mlp = MLP(
            num_embs * self.mem_size + config.num_feats * config.emb_size,
            config.mlp_size,
            1,
            drop_module=drop_module,
            num_hidden_layers=config.mlp_depth,
            bias=True,
        )
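        # Input-size check with assumed illustrative values (not prescribed by
        # the config): span_emb_size=64, num_feats=2, emb_size=20 gives a pair
        # vector of 3 * 64 + 2 * 20 = 232 dims in the non-endpoint case.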

        if config.entity_rep == "learned_avg":
            # Learned gate for updating the cluster representation
            self.alpha = MLP(
                2 * self.mem_size,
                config.mlp_size,
                1,
                num_hidden_layers=1,
                bias=True,
                drop_module=drop_module,
            )

        if config.pseudo_dist:
            # Reserve one extra bucket for clusters with no previous mention
            self.distance_embeddings = nn.Embedding(
                self.config.num_embeds + 1, config.emb_size
            )
        else:
            self.distance_embeddings = nn.Embedding(
                self.config.num_embeds, config.emb_size
            )

        self.counter_embeddings = nn.Embedding(self.config.num_embeds, config.emb_size)

    @property
    def device(self) -> torch.device:
        return next(self.mem_coref_mlp.parameters()).device

    def initialize_memory(
        self,
        mem: Optional[Tensor] = None,
        mem_init: Optional[Tensor] = None,
        ent_counter: Optional[Tensor] = None,
        last_mention_start: Optional[Tensor] = None,
        rep: Optional[List[Tensor]] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Initialize the clusters and related bookkeeping variables."""
        if rep is None:
            rep = []
        # Check for uninitialized memory
        if mem is None or ent_counter is None or last_mention_start is None:
            mem = torch.zeros(len(rep), self.mem_size).to(self.device)
            mem_init = torch.zeros(len(rep), self.mem_size).to(self.device)
            for idx, rep_vec in enumerate(rep):
                mem[idx] = rep_vec
                mem_init[idx] = rep_vec
            ent_counter = torch.tensor([1.0] * len(rep)).to(self.device)
            last_mention_start = -torch.ones(len(rep)).long().to(self.device)
        elif len(rep):
            # Append each new representation to the existing memory
            for rep_emb in rep:
                mem = torch.cat([mem, rep_emb.unsqueeze(0).to(self.device)], dim=0)
                mem_init = torch.cat(
                    [mem_init, rep_emb.unsqueeze(0).to(self.device)], dim=0
                )
                ent_counter = torch.cat(
                    [ent_counter, torch.tensor([1.0]).to(self.device)]
                )
                last_mention_start = torch.cat(
                    [last_mention_start, torch.tensor([-1]).to(self.device)]
                )

        return mem, mem_init, ent_counter, last_mention_start
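
    # Shape summary for the tuple returned above (E = number of clusters,
    # D = mem_size): mem [E, D], mem_init [E, D], ent_counter [E],
    # last_mention_start [E], where -1 marks clusters with no mention seen yet.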

    @staticmethod
    def get_bucket(count: Tensor) -> Tensor:
        """Bucket distances and entity counts with the same logic."""
        logspace_idx = (
            torch.floor(
                torch.log(torch.max(count.float(), torch.tensor(1.0))) / LOG2
            ).long()
            + 3
        )
        use_identity = (count <= 4).long()
        combined_idx = use_identity * count + (1 - use_identity) * logspace_idx
        return torch.clamp(combined_idx, 0, 9)
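
    # Worked example of the bucketing above: counts 0-4 keep their own bucket,
    # 5-7 -> 5, 8-15 -> 6, 16-31 -> 7, 32-63 -> 8, and >= 64 clamps to 9, e.g.
    # get_bucket(torch.tensor([3, 5, 12, 100])) -> tensor([3, 5, 6, 9]).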

    @staticmethod
    def get_distance_bucket(distances: Tensor) -> Tensor:
        return BaseMemory.get_bucket(distances)

    @staticmethod
    def get_counter_bucket(count: Tensor) -> Tensor:
        return BaseMemory.get_bucket(count)

    def get_distance_emb(self, distance: Tensor) -> Tensor:
        distance_tens = self.get_distance_bucket(distance)
        distance_embs = self.distance_embeddings(distance_tens)
        return distance_embs

    def get_counter_emb(self, ent_counter: Tensor) -> Tensor:
        counter_buckets = self.get_counter_bucket(ent_counter.long())
        counter_embs = self.counter_embeddings(counter_buckets)
        return counter_embs

    @staticmethod
    def get_coref_mask(ent_counter: Tensor) -> Tensor:
        """Mask for whether each cluster representation corresponds to an entity."""
        cell_mask = (ent_counter > 0.0).float()
        return cell_mask

    def get_feature_embs_tensorized(
        self,
        ment_start: Tensor,  # [B]
        last_mention_start: Tensor,  # [E]
        ent_counter: Tensor,  # [E]
        metadata: Dict,  # assumed empty; unused here
    ) -> Tensor:
        # Returns [B, E, 2 * emb_size]
        # Distance embeddings: this path always uses the extra pseudo-distance
        # bucket (index num_embeds), which only exists when config.pseudo_dist
        # is set
        distance_embs = self.distance_embeddings(
            torch.tensor(self.config.num_embeds).long().to(self.device)
        ).repeat(
            ment_start.shape[0], last_mention_start.shape[0], 1
        )  # [B, E, emb_size]

        # Counter embeddings
        ent_counter_batch = ent_counter.unsqueeze(0).repeat(
            ment_start.shape[0], 1
        )  # [B, E]
        counter_embs = self.get_counter_emb(ent_counter_batch)  # [B, E, emb_size]

        feature_embs_list = [distance_embs, counter_embs]
        feature_embs = self.drop_module(torch.cat(feature_embs_list, dim=-1))
        return feature_embs

    def get_feature_embs(
        self,
        ment_start: Tensor,
        last_mention_start: Tensor,
        ent_counter: Tensor,
        metadata: Dict,
    ) -> Tensor:
        distance_embs = self.get_distance_emb(ment_start - last_mention_start)
        if self.config.pseudo_dist:
            # Clusters with no previous mention (last_mention_start < 0) get
            # the dedicated pseudo-distance bucket instead of a real distance
            rep_distance_mask = (last_mention_start < 0).unsqueeze(1).float()
            rep_distance_embs = self.distance_embeddings(
                torch.tensor(self.config.num_embeds).long().to(self.device)
            ).repeat(last_mention_start.shape[0], 1)
            distance_embs = (
                distance_embs * (1 - rep_distance_mask)
                + rep_distance_embs * rep_distance_mask
            )
        counter_embs = self.get_counter_emb(ent_counter)

        feature_embs_list = [distance_embs, counter_embs]
        if "genre" in metadata:
            genre_emb = metadata["genre"]
            num_ents = distance_embs.shape[0]
            genre_emb = torch.unsqueeze(genre_emb, dim=0).repeat(num_ents, 1)
            feature_embs_list.append(genre_emb)

        feature_embs = self.drop_module(torch.cat(feature_embs_list, dim=-1))
        return feature_embs

    def get_coref_new_scores_tensorized(
        self,
        ment_emb: Tensor,  # [B, D]
        mem_vectors: Tensor,  # [E, D]
        mem_vectors_init: Tensor,  # [E, D]; not used here
        ent_counter: Tensor,  # not used here
        feature_embs: Tensor,  # [B, E, 2 * emb_size]
    ) -> Tensor:
        rep_ment_emb = ment_emb.unsqueeze(1).repeat(
            1, mem_vectors.shape[0], 1
        )  # [B, E, D]
        rep_mem_vectors = mem_vectors.unsqueeze(0).repeat(
            ment_emb.shape[0], 1, 1
        )  # [B, E, D]
        pair_vec = torch.cat(
            [
                rep_mem_vectors,
                rep_ment_emb,
                rep_mem_vectors * rep_ment_emb,
                feature_embs,
            ],
            dim=-1,
        )  # [B, E, 3D + 2 * emb_size]
        pair_score = self.mem_coref_mlp(pair_vec)
        coref_score = torch.squeeze(pair_score, dim=-1)  # [B, E]
        # The new-cluster score for every mention is the (free) threshold
        base_col = (
            torch.ones(coref_score.shape[0], 1).to(self.device) * self.config.thresh
        )
        coref_new_score = torch.cat([coref_score, base_col], dim=-1)  # [B, E + 1]
        return coref_new_score

    def get_coref_new_scores(
        self,
        ment_emb: Tensor,
        mem_vectors: Tensor,
        mem_vectors_init: Tensor,
        ent_counter: Tensor,
        feature_embs: Tensor,
    ) -> Tensor:
        """Calculate the coreference score with existing clusters.

        For creating a new cluster we use a fixed dummy score (config.thresh).
        Treating the new-cluster score as a free variable is borrowed from
        Lee et al. (2017).

        Args:
            ment_emb (d'): Mention representation.
            mem_vectors (M x d'): Cluster representations.
            mem_vectors_init (M x d'): Initial cluster representations, used
                for the additional similarity term when type is "hybrid".
            ent_counter (M): Mention counter of clusters.
            feature_embs (M x p): Embedding of features such as distance from
                the last mention of the cluster.

        Returns:
            coref_new_score (M + 1):
                Coref scores concatenated with the score of forming a new cluster.
        """
        # Repeat the query vector for comparison against all cells
        num_ents = mem_vectors.shape[0]
        rep_ment_emb = ment_emb.repeat(num_ents, 1)  # M x H

        # Coref score
        if self.config.sim_func == "endpoint":
            pair_vec = torch.cat([mem_vectors, rep_ment_emb, feature_embs], dim=-1)
            pair_score = self.mem_coref_mlp(pair_vec)
            if self.config.type == "hybrid":
                # Add pairwise similarity with the initial memory
                pair_vec_init = torch.cat(
                    [mem_vectors_init, rep_ment_emb, feature_embs], dim=-1
                )
                pair_score_init = self.mem_coref_mlp(pair_vec_init)
                pair_score = pair_score + pair_score_init
        else:
            # Pairwise similarity with mem; mem is dynamic unless type is "static"
            pair_vec = torch.cat(
                [mem_vectors, rep_ment_emb, mem_vectors * rep_ment_emb, feature_embs],
                dim=-1,
            )
            pair_score = self.mem_coref_mlp(pair_vec)
            if self.config.type == "hybrid":
                # Add pairwise similarity with the initial memory
                pair_vec_init = torch.cat(
                    [
                        mem_vectors_init,
                        rep_ment_emb,
                        mem_vectors_init * rep_ment_emb,
                        feature_embs,
                    ],
                    dim=-1,
                )
                pair_score_init = self.mem_coref_mlp(pair_vec_init)  # Static score
                # Combine similarity with the current and initial representations
                pair_score = pair_score + pair_score_init

        coref_score = torch.squeeze(pair_score, dim=-1)  # M

        coref_new_mask = torch.cat(
            [self.get_coref_mask(ent_counter), torch.tensor([1.0], device=self.device)],
            dim=0,
        )
        # Use the threshold as the dummy score for forming a new cluster
        coref_new_score = torch.cat(
            [coref_score, torch.tensor([self.config.thresh], device=self.device)],
            dim=0,
        )
        # Mask out memory slots that don't yet correspond to any entity
        coref_new_score = coref_new_score * coref_new_mask + (1 - coref_new_mask) * (
            -1e4
        )
        return coref_new_score
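
    # The returned vector holds one score per existing cluster followed by the
    # new-cluster score, e.g. for M = 3 clusters it is [s_0, s_1, s_2, thresh],
    # with invalid slots pushed to -1e4 by the mask above.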

    @staticmethod
    def assign_cluster_tensorized(
        coref_new_scores: Tensor,
    ) -> Iterator[Tuple[int, str]]:
        """Decode the actions from the argmax of the clustering scores.

        coref_new_scores: [B, E + 1]
        """
        num_ents = coref_new_scores.shape[-1] - 1
        pred_max_idx = torch.argmax(coref_new_scores, dim=-1).tolist()  # [B]
        # "c" = coref with an existing cluster, "o" = open a new cluster
        action_str = ["c" if idx < num_ents else "o" for idx in pred_max_idx]
        return zip(pred_max_idx, action_str)

    @staticmethod
    def assign_cluster(coref_new_scores: Tensor) -> Tuple[int, str]:
        """Decode the action from the argmax of the clustering scores."""
        num_ents = coref_new_scores.shape[0] - 1
        pred_max_idx = torch.argmax(coref_new_scores).item()
        if pred_max_idx < num_ents:
            # Coref with an existing cluster
            return pred_max_idx, "c"
        else:
            # Open a new cluster
            return num_ents, "o"
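
    # E.g., with scores tensor([0.2, 1.5, 0.0]) over two clusters plus the
    # new-cluster slot, assign_cluster returns (1, "c"); if the last slot wins,
    # it returns (num_ents, "o").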

    def coref_update(
        self, ment_emb: Tensor, mem_vectors: Tensor, cell_idx: int, ent_counter: Tensor
    ) -> Tensor:
        """Update the cluster representation given the new mention representation."""
        if self.config.entity_rep == "learned_avg":
            # Learned gate interpolating between the old cluster vector and the mention
            alpha_wt = torch.sigmoid(
                self.alpha(torch.cat([mem_vectors[cell_idx], ment_emb], dim=0))
            )
            coref_vec = alpha_wt * mem_vectors[cell_idx] + (1 - alpha_wt) * ment_emb
        elif self.config.entity_rep == "max":
            # Element-wise max pooling
            coref_vec = torch.max(mem_vectors[cell_idx], ment_emb)
        else:
            # Running average over all mentions of the cluster
            cluster_count = ent_counter[cell_idx].item()
            coref_vec = (mem_vectors[cell_idx] * cluster_count + ment_emb) / (
                cluster_count + 1
            )
        return coref_vec
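

# A minimal usage sketch, illustrative only: the config keys below mirror the
# attributes consumed by BaseMemory, but the values are assumptions rather
# than the authors' settings, and `pytorch_utils.modules.MLP` must be
# importable as at the top of this file.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "sim_func": "hadamard",  # any value other than "endpoint"
            "num_feats": 2,
            "emb_size": 20,
            "mlp_size": 1024,
            "mlp_depth": 1,
            "entity_rep": "wt_avg",  # falls through to the running-average update
            "pseudo_dist": False,
            "num_embeds": 10,
            "type": "hybrid",
            "thresh": 0.0,
        }
    )
    memory = BaseMemory(config, span_emb_size=64, drop_module=nn.Dropout(0.3))
    memory.eval()  # disable dropout for a deterministic demo

    # Two singleton clusters, then score a new mention against them
    reps = [torch.randn(64), torch.randn(64)]
    mem, mem_init, ent_counter, last_mention_start = memory.initialize_memory(rep=reps)

    ment_emb = torch.randn(64)
    feature_embs = memory.get_feature_embs(
        torch.tensor(10), last_mention_start, ent_counter, metadata={}
    )
    scores = memory.get_coref_new_scores(
        ment_emb, mem, mem_init, ent_counter, feature_embs
    )
    cell_idx, action = BaseMemory.assign_cluster(scores)
    print(f"action={action}, cell={cell_idx}")  # "c" = coref, "o" = new cluster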