# MEIRa/model/memory/base_memory.py
import torch
import torch.nn as nn
from pytorch_utils.modules import MLP
import math
from omegaconf import DictConfig
from typing import Dict, Iterator, Optional, Tuple
from torch import Tensor
LOG2 = math.log(2)
class BaseMemory(nn.Module):
"""Base clustering module."""
def __init__(self, config: DictConfig, span_emb_size: int, drop_module: nn.Module):
        super().__init__()
self.config = config
self.mem_size = span_emb_size
self.drop_module = drop_module
if self.config.sim_func == "endpoint":
num_embs = 2 # Span start, Span end
else:
num_embs = 3 # Span start, Span end, Hadamard product between the two
self.mem_coref_mlp = MLP(
num_embs * self.mem_size + config.num_feats * config.emb_size,
config.mlp_size,
1,
drop_module=drop_module,
num_hidden_layers=config.mlp_depth,
bias=True,
)
if config.entity_rep == "learned_avg":
# Parameter for updating the cluster representation
self.alpha = MLP(
2 * self.mem_size,
config.mlp_size,
1,
num_hidden_layers=1,
bias=True,
drop_module=drop_module,
)
        if config.pseudo_dist:
            # Reserve an extra slot (index num_embeds) for the pseudo-distance
            # used when a cluster has no prior mention in the document
            self.distance_embeddings = nn.Embedding(
                self.config.num_embeds + 1, config.emb_size
            )
else:
self.distance_embeddings = nn.Embedding(
self.config.num_embeds, config.emb_size
)
self.counter_embeddings = nn.Embedding(self.config.num_embeds, config.emb_size)
@property
def device(self) -> torch.device:
return next(self.mem_coref_mlp.parameters()).device
    def initialize_memory(
        self,
        mem: Optional[Tensor] = None,
        mem_init: Optional[Tensor] = None,
        ent_counter: Optional[Tensor] = None,
        last_mention_start: Optional[Tensor] = None,
        rep=(),
        **kwargs
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Method to initialize the clusters and related bookkeeping variables."""
        # Check for uninitialized memory
if mem is None or ent_counter is None or last_mention_start is None:
mem = torch.zeros(len(rep), self.mem_size).to(self.device)
mem_init = torch.zeros(len(rep), self.mem_size).to(self.device)
for idx, rep_vec in enumerate(rep):
mem[idx] = rep_vec
mem_init[idx] = rep_vec
ent_counter = torch.tensor([1.0] * len(rep)).to(self.device)
last_mention_start = -torch.ones(len(rep)).long().to(self.device)
elif len(rep):
for rep_emb in rep:
mem = torch.cat([mem, rep_emb.unsqueeze(0).to(self.device)], dim=0)
mem_init = torch.cat(
[mem_init, rep_emb.unsqueeze(0).to(self.device)], dim=0
)
ent_counter = torch.cat(
[ent_counter, torch.tensor([1.0]).to(self.device)]
)
last_mention_start = torch.cat(
[last_mention_start, torch.tensor([-1]).to(self.device)]
)
return mem, mem_init, ent_counter, last_mention_start
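    # For E tracked entities the method returns mem [E, mem_size],
    # mem_init [E, mem_size], ent_counter [E] (mentions seen per cluster),
    # and last_mention_start [E] (-1 until a mention is observed).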
@staticmethod
def get_bucket(count: Tensor) -> Tensor:
"""Bucket distance and entity counters using the same logic."""
logspace_idx = (
torch.floor(
torch.log(torch.max(count.float(), torch.tensor(1.0))) / LOG2
).long()
+ 3
)
        # Identity buckets for values up to 4, log-spaced buckets above
        use_identity = (count <= 4).long()
        combined_idx = use_identity * count.long() + (1 - use_identity) * logspace_idx
return torch.clamp(combined_idx, 0, 9)
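    # Worked example of the mapping: 0-4 keep their own bucket, then
    # 5-7 -> 5, 8-15 -> 6, 16-31 -> 7, 32-63 -> 8, and 64+ -> 9, i.e. ten
    # buckets in total (presumably config.num_embeds = 10, matching the
    # clamp above).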
@staticmethod
def get_distance_bucket(distances: Tensor) -> Tensor:
return BaseMemory.get_bucket(distances)
@staticmethod
def get_counter_bucket(count: Tensor) -> Tensor:
return BaseMemory.get_bucket(count)
def get_distance_emb(self, distance: Tensor) -> Tensor:
distance_tens = self.get_distance_bucket(distance)
distance_embs = self.distance_embeddings(distance_tens)
return distance_embs
def get_counter_emb(self, ent_counter: Tensor) -> Tensor:
counter_buckets = self.get_counter_bucket(ent_counter.long())
counter_embs = self.counter_embeddings(counter_buckets)
return counter_embs
@staticmethod
def get_coref_mask(ent_counter: Tensor) -> Tensor:
"""Mask for whether the cluster representation corresponds to any entity or not."""
cell_mask = (ent_counter > 0.0).float()
return cell_mask
def get_feature_embs_tensorized(
self,
ment_start: Tensor, ## [B]
last_mention_start: Tensor, ## [E]
ent_counter: Tensor, ## [E]
metadata: Dict, ## [Assuming no metadata]
    ) -> Tensor:
        ## Returns [B, E, num_feats * emb_size]: distance + counter embeddings
        ## Get distance embeddings: every pair uses the pseudo-distance slot
        ## (index num_embeds), which only exists when config.pseudo_dist is set
        distance_embs = self.distance_embeddings(
            torch.tensor(self.config.num_embeds).long().to(self.device)
        ).repeat(
            ment_start.shape[0], last_mention_start.shape[0], 1
        )  ## [B, E, 20]
## Get counter embeddings
ent_counter_batch = ent_counter.unsqueeze(0).repeat(
ment_start.shape[0], 1
) ## [B, E]
counter_embs = self.get_counter_emb(ent_counter_batch) ## [B, E, 20]
feature_embs_list = [distance_embs, counter_embs]
feature_embs = self.drop_module(torch.cat(feature_embs_list, dim=-1))
return feature_embs
def get_feature_embs(
self,
ment_start: Tensor,
last_mention_start: Tensor,
ent_counter: Tensor,
metadata: Dict,
) -> Tensor:
distance_embs = self.get_distance_emb(ment_start - last_mention_start)
        if self.config.pseudo_dist:
            # For clusters with no prior mention (last_mention_start < 0),
            # substitute the dedicated pseudo-distance embedding
            rep_distance_mask = (last_mention_start < 0).unsqueeze(1).float()
rep_distance_embs = self.distance_embeddings(
torch.tensor(self.config.num_embeds).long().to(self.device)
).repeat(last_mention_start.shape[0], 1)
distance_embs = (
distance_embs * (1 - rep_distance_mask)
+ rep_distance_embs * rep_distance_mask
)
counter_embs = self.get_counter_emb(ent_counter)
feature_embs_list = [distance_embs, counter_embs]
if "genre" in metadata:
genre_emb = metadata["genre"]
num_ents = distance_embs.shape[0]
genre_emb = torch.unsqueeze(genre_emb, dim=0).repeat(num_ents, 1)
feature_embs_list.append(genre_emb)
feature_embs = self.drop_module(torch.cat(feature_embs_list, dim=-1))
return feature_embs
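    # Note: the concatenated feature width here is num_feats * emb_size and
    # must match the feature slice of mem_coref_mlp's input; num_feats is
    # presumably 2 (distance + counter), or 3 when metadata carries a genre
    # embedding.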
def get_coref_new_scores_tensorized(
self,
ment_emb: Tensor, ## [B,D]
mem_vectors: Tensor, ## [E,D]
mem_vectors_init: Tensor, ## [E,D] ## Not used here
ent_counter: Tensor, ## not used here
feature_embs: Tensor, ## [B,E,20]
) -> Tensor:
rep_ment_emb = ment_emb.unsqueeze(1).repeat(
1, mem_vectors.shape[0], 1
) ## [B,E,D]
rep_mem_vectors = mem_vectors.unsqueeze(0).repeat(
ment_emb.shape[0], 1, 1
) ## [B,E,D]
pair_vec = torch.cat(
[
rep_mem_vectors,
rep_ment_emb,
rep_mem_vectors * rep_ment_emb,
feature_embs,
],
dim=-1,
        )  ## [B, E, 3D + num_feats * emb_size]
pair_score = self.mem_coref_mlp(pair_vec)
coref_score = torch.squeeze(pair_score, dim=-1) # [B,E]
        # Score for forming a new cluster: a fixed threshold (config.thresh)
        base_col = (
            torch.ones(coref_score.shape[0], 1).to(self.device) * self.config.thresh
        )
coref_new_score = torch.cat([coref_score, base_col], dim=-1) ## [B,E+1]
return coref_new_score
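    # Unlike get_coref_new_scores below, this tensorized variant scores a
    # whole mention batch against every memory slot at once and appends the
    # new-cluster threshold as a final column; it applies no mask for
    # empty clusters.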
def get_coref_new_scores(
self,
ment_emb: Tensor,
mem_vectors: Tensor,
mem_vectors_init: Tensor,
ent_counter: Tensor,
feature_embs: Tensor,
) -> Tensor:
"""Calculate the coreference score with existing clusters.
For creating a new cluster we use a dummy score of 0.
This is a free variable and this idea is borrowed from Lee et al 2017
Args:
ment_emb (d'): Mention representation
mem_vectors (M x d'): Cluster representations
ent_counter (M): Mention counter of clusters.
feature_embs (M x p): Embedding of features such as distance from last
mention of the cluster.
Returns:
coref_new_score (M + 1):
Coref scores concatenated with the score of forming a new cluster.
"""
# Repeat the query vector for comparison against all cells
num_ents = mem_vectors.shape[0]
rep_ment_emb = ment_emb.repeat(num_ents, 1) # M x H
# Coref Score
if self.config.sim_func == "endpoint":
pair_vec = torch.cat([mem_vectors, rep_ment_emb, feature_embs], dim=-1)
pair_score = self.mem_coref_mlp(pair_vec)
if self.config.type == "hybrid":
## Adding pairwise similarity with initial memory
pair_vec_init = torch.cat(
[mem_vectors_init, rep_ment_emb, feature_embs], dim=-1
)
pair_score_init = self.mem_coref_mlp(pair_vec_init)
pair_score = pair_score + pair_score_init
        else:
            ## Pairwise similarity computed against mem, which is updated
            ## dynamically unless config.type is "static"
pair_vec = torch.cat(
[mem_vectors, rep_ment_emb, mem_vectors * rep_ment_emb, feature_embs],
dim=-1,
)
pair_score = self.mem_coref_mlp(pair_vec)
if self.config.type == "hybrid":
## Adding pairwise similarity with initial memory
pair_vec_init = torch.cat(
[
mem_vectors_init,
rep_ment_emb,
mem_vectors_init * rep_ment_emb,
feature_embs,
],
dim=-1,
)
pair_score_init = self.mem_coref_mlp(pair_vec_init) ## Static score
pair_score = (
pair_score + pair_score_init
) ## Similarity score with current repr. and initial repr.
coref_score = torch.squeeze(pair_score, dim=-1) # M
coref_new_mask = torch.cat(
[self.get_coref_mask(ent_counter), torch.tensor([1.0], device=self.device)],
dim=0,
)
        # Use a fixed threshold score (config.thresh) for forming a new cluster
        coref_new_score = torch.cat(
            [coref_score, torch.tensor([self.config.thresh], device=self.device)],
            dim=0,
        )
coref_new_score = coref_new_score * coref_new_mask + (1 - coref_new_mask) * (
-1e4
)
return coref_new_score
    @staticmethod
    def assign_cluster_tensorized(coref_new_scores: Tensor) -> Iterator[Tuple[int, str]]:
        """Decode the actions from the argmax of clustering scores."""
        ## coref_new_scores : [B, E+1]
        num_ents = coref_new_scores.shape[-1] - 1
        pred_max_idx = torch.argmax(coref_new_scores, dim=-1).tolist()  ## [B]
        action_str = ["c" if idx < num_ents else "o" for idx in pred_max_idx]
        return zip(pred_max_idx, action_str)
@staticmethod
def assign_cluster(coref_new_scores: Tensor) -> Tuple[int, str]:
"""Decode the action from argmax of clustering scores"""
num_ents = coref_new_scores.shape[0] - 1
pred_max_idx = torch.argmax(coref_new_scores).item()
if pred_max_idx < num_ents:
# Coref
return pred_max_idx, "c"
else:
# New cluster
return num_ents, "o"
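    # Example: for scores [2.1, -0.5, 0.3] over M = 2 clusters plus the
    # new-cluster slot, the argmax is index 0 < 2, so assign_cluster returns
    # (0, "c"); had the final slot won, it would return (2, "o").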
def coref_update(
self, ment_emb: Tensor, mem_vectors: Tensor, cell_idx: int, ent_counter: Tensor
) -> Tensor:
"""Updates the cluster representation given the new mention representation."""
if self.config.entity_rep == "learned_avg":
alpha_wt = torch.sigmoid(
self.alpha(torch.cat([mem_vectors[cell_idx], ment_emb], dim=0))
)
coref_vec = alpha_wt * mem_vectors[cell_idx] + (1 - alpha_wt) * ment_emb
elif self.config.entity_rep == "max":
coref_vec = torch.max(mem_vectors[cell_idx], ment_emb)
        else:
            # Default: running average of all mention representations so far
            cluster_count = ent_counter[cell_idx].item()
coref_vec = (mem_vectors[cell_idx] * cluster_count + ment_emb) / (
cluster_count + 1
)
return coref_vec
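

# A minimal smoke-test sketch (not part of the original module), assuming
# pytorch_utils.modules.MLP accepts the arguments used above and that the
# config values below (e.g. sim_func="hadamard", num_embeds=10) are valid
# settings for this repo; treat it as an illustration, not canonical usage.
if __name__ == "__main__":
    config = DictConfig(
        {
            "sim_func": "hadamard",  # any value other than "endpoint"
            "type": "static",
            "num_feats": 2,  # distance + counter features
            "emb_size": 20,
            "mlp_size": 64,
            "mlp_depth": 1,
            "entity_rep": "avg",  # falls through to the running-average update
            "pseudo_dist": True,
            "num_embeds": 10,
            "thresh": 0.0,
        }
    )
    memory = BaseMemory(config, span_emb_size=32, drop_module=nn.Dropout(0.1))

    # Start two clusters from dummy representations, then score a new mention
    mem, mem_init, ent_counter, last_start = memory.initialize_memory(
        rep=[torch.randn(32), torch.randn(32)]
    )
    feature_embs = memory.get_feature_embs(
        torch.tensor(50), last_start, ent_counter, metadata={}
    )
    scores = memory.get_coref_new_scores(
        torch.randn(32), mem, mem_init, ent_counter, feature_embs
    )
    print(BaseMemory.assign_cluster(scores))  # e.g. (0, "c") or (2, "o")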