"""Trainer for a session-based recommendation model built on text embeddings."""

from os import environ, path

from transformers import BertTokenizerFast, FlaxAutoModel
import jax.numpy as jnp
import jax
from flax.training.train_state import TrainState
import pandas as pd

from tyrec.trainer import BaseTrainer, loss, HFConfig
from tyrec.recommendations.model import RecommendationModel
from tyrec.utils import compute_mean, logger
from tyrec.evaluator import RetrivalEvaluator
from tyrec.utils import trange

# Number of item rows sent through the text encoder per forward pass.
_ITEM_BATCH_SIZE = 128
# Token length every item text is padded / truncated to.
_MAX_TOKENS = 70
# Lower bound on a norm before dividing, so all-zero vectors do not become NaN.
_NORM_EPS = 1e-12


def _l2_normalize(x, axis=-1):
    """Return ``x`` scaled to unit L2 norm along ``axis``.

    The norm is clamped to ``_NORM_EPS`` so an all-zero vector maps to zeros
    rather than NaN; every other vector is divided by its exact norm.
    """
    norm = jnp.linalg.norm(x, axis=axis, keepdims=True)
    return x / jnp.maximum(norm, _NORM_EPS)


class RecommendationsTrainer(BaseTrainer):
    """Train a ``RecommendationModel`` on per-session event histories.

    Items are embedded with a pretrained text encoder, each session is
    summarised by ``compute_mean`` over the items in its history, and the
    model is trained to pick the session's last item among all items.
    """

    def __init__(
        self,
        hf_config: HFConfig,
        data_dir,
        event_to_train,
        threshold=0,
        dimensions=0,
        other_features=False,
        *args,
        **kwargs,
    ):
        """Load the encoder, the tokenizer and the raw event / item tables.

        Args:
            hf_config: Supplies ``model_name``, ``query_prompt`` and
                ``doc_prompt``.
            data_dir: Directory (relative to the base path) holding
                ``train.jsonl``, ``test.jsonl``, ``items.jsonl`` and the
                optional cached embedding files.
            event_to_train: Only rows whose ``event`` column equals this value
                are kept from the train and test files.
            threshold: Stored on the instance as ``self.threshold``.
            dimensions: Size passed to ``RecommendationModel``; values ``<= 0``
                fall back to the text encoder's ``hidden_size``.
            other_features: When true, each item's ``features`` column is
                appended to its text embedding.
            *args: Forwarded to ``BaseTrainer.__init__``.
            **kwargs: Forwarded to ``BaseTrainer.__init__``.
        """
        self.model_name = hf_config.model_name
        self.data_dir = data_dir
        self.query_prompt = hf_config.query_prompt
        self.doc_prompt = hf_config.doc_prompt
        self.event_to_train = event_to_train
        self.other_features = other_features
        self.text_encoder = FlaxAutoModel.from_pretrained(self.model_name)
        self.tokenizer = BertTokenizerFast.from_pretrained(self.model_name)
        # On Colab (COLAB_ENV == "true") the data lives on the mounted Drive.
        self.base_path = (
            "/content/drive/MyDrive/"
            if environ.get("COLAB_ENV") == "true"
            else "./"
        )

        def data_path(file_name):
            return path.join(self.base_path, data_dir, file_name)

        self.train_file = data_path("train.jsonl")
        self.eval_file = data_path("test.jsonl")
        self.items_file = data_path("items.jsonl")
        self.item_embeddings_file = data_path("item_embeds.jsonl")
        self.train_user_embeddings_file = data_path("train_user_embeds.jsonl")
        self.test_user_embeddings_file = data_path("test_user_embeds.jsonl")
        self.dimensions = (
            dimensions if dimensions > 0 else self.text_encoder.config.hidden_size
        )
        model = RecommendationModel(self.dimensions)
        super().__init__(*args, model=model, **kwargs)

        self.dataloader = self._read_events(self.train_file)
        self.test_dataset = self._read_events(self.eval_file)
        unique_did = (
            pd.concat([self.dataloader, self.test_dataset], ignore_index=True)["did"]
            .unique()
            .tolist()
        )
        # Keep only items that occur in the filtered events; the reset index
        # makes the row label double as the item's position.
        items = pd.read_json(self.items_file, lines=True)
        self.items = items[items["did"].isin(unique_did)].reset_index(drop=True)
        self.threshold = threshold
        self.evaluator: RetrivalEvaluator | None = None
        self.item_embeds = []
        self.train_user_embeds = []
        self.train_label_embeds = []
        self.test_user_embeds = []
        self.test_label_embeds = []
        self.rng = jax.random.PRNGKey(0)

    def _read_events(self, file_path):
        """Read a JSON-lines event file and keep rows of ``self.event_to_train``."""
        events = pd.read_json(file_path, lines=True)
        return events[events["event"] == self.event_to_train].reset_index(drop=True)

    def _item_positions(self):
        """Map each item ``did`` to the index label of its first row in ``self.items``."""
        positions = {}
        for idx, did in zip(self.items.index.tolist(), self.items["did"].tolist()):
            positions.setdefault(did, idx)
        return positions

    @staticmethod
    def _label_index(positions, did):
        """Look up ``did`` in ``positions``.

        Raises:
            IndexError: If ``did`` is unknown. ``IndexError`` is kept because
                the earlier per-row DataFrame scan failed with that type.
        """
        try:
            return positions[did]
        except KeyError:
            raise IndexError(f"label item {did!r} is not present in items") from None

    def embed_items(self, examples):
        """Embed a batch of item rows with the text encoder.

        Args:
            examples: DataFrame slice with a ``text`` column and, when
                ``other_features`` is set, a ``features`` column of
                equal-length numeric sequences.

        Returns:
            A list with one vector per row: the unit-normalised, mean-pooled
            encoder output, followed by the row's features when
            ``other_features`` is set.
        """
        texts = [self.doc_prompt + x for x in examples["text"].tolist()]
        tokens = self.tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            return_tensors="jax",
            max_length=_MAX_TOKENS,
        )
        embeddings = self.text_encoder(**tokens).last_hidden_state
        # mean_pooling is not defined in this file; presumably inherited from
        # BaseTrainer.
        embeddings = self.mean_pooling(embeddings, tokens["attention_mask"])
        embeddings = _l2_normalize(embeddings)
        if self.other_features:
            features = jnp.asarray(examples["features"].tolist())
            embeddings = jnp.concatenate([embeddings, features], axis=-1)
        return list(embeddings)

    def embed_events(self, df):
        """Turn grouped sessions into user vectors and label indices.

        Args:
            df: Output of ``group_events`` (columns ``data`` and ``label``).

        Returns:
            ``(user_vecs, label_vecs)``: one ``compute_mean`` result per
            session, and the item index of each session's label.

        Raises:
            IndexError: If a label's ``did`` is not present in ``self.items``.
        """
        positions = self._item_positions()
        item_dids = self.items["did"]
        num_items = len(self.item_embeds)
        user_vecs = []
        label_vecs = []
        for _, row in df.iterrows():
            label = self._label_index(positions, row["label"]["did"])
            history = [x["did"] for x in row["data"]]
            # The multi-hot vector only records which items were seen, so the
            # order of the history does not matter here.
            seen = item_dids.index[item_dids.isin(history).to_numpy()]
            multi_hot = [0] * num_items
            for idx in seen:
                multi_hot[idx] = 1
            user_vecs.append(compute_mean(self.item_embeds, jnp.array(multi_hot)))
            label_vecs.append(label)
        return jnp.array(user_vecs), jnp.array(label_vecs)

    def group_events(self, df):
        """Collapse an event table into one row per session.

        Events are ordered by ``ts`` within each ``sid``. Sessions with more
        than two events yield a row whose ``data`` is every event but the last
        and whose ``label`` is the last event; shorter sessions are dropped.

        Returns:
            DataFrame with columns ``sid``, ``data`` and ``label`` (the
            columns are present even when no session qualifies).
        """
        df = df.sort_values(["sid", "ts"])
        grouped_data = []
        # Grouping by the bare column name yields scalar keys on every pandas
        # version; a one-element list yields tuples only on pandas >= 2.
        for sid, group in df.groupby("sid"):
            events = group.drop(["sid", "event"], axis=1).to_dict("records")
            if len(events) > 2:
                grouped_data.append(
                    {"sid": sid, "data": events[:-1], "label": events[-1]}
                )
        return pd.DataFrame(grouped_data, columns=["sid", "data", "label"])

    @staticmethod
    def users_to_sessions(file_path, threshold):
        """Read events and number the activity intervals inside each ``sid``.

        Args:
            file_path: JSON-lines file with at least ``sid`` and ``ts``.
            threshold: When positive, a gap between consecutive ``ts`` values
                larger than this starts a new interval; otherwise every row
                gets interval ``0``.

        Returns:
            The events with an added ``interval`` column. With a positive
            threshold the rows come back grouped by ``sid`` with a fresh index.
        """
        df = pd.read_json(file_path, lines=True)
        # An empty file has nothing to group (and pd.concat rejects an empty
        # list), so it takes the constant-interval path.
        if threshold > 0 and not df.empty:

            def create_intervals(group):
                group = group.copy()
                gaps = group["ts"].diff()
                group["interval"] = (gaps > threshold).cumsum()
                return group

            df = pd.concat(
                [create_intervals(group) for _, group in df.groupby("sid")],
                ignore_index=True,
            )
        else:
            df["interval"] = 0
        return df

    def load_item_embeddings(self):
        """Load cached item embeddings, aligned row-for-row with ``self.items``.

        Raises:
            ValueError: If the cache has no embedding for some item.
        """
        item_embeds = pd.read_json(self.item_embeddings_file, lines=True)
        item_with_embeds = pd.merge(self.items, item_embeds, on="did", how="left")
        if item_with_embeds["embed"].isna().any():
            raise ValueError(
                f"{self.item_embeddings_file} has no embedding for some items; "
                "delete the file to recompute it"
            )
        return [jnp.array(x) for x in item_with_embeds["embed"]]

    def load_user_embeddings(self, df, file_path):
        """Load cached user embeddings for the sessions in ``df``.

        Args:
            df: Output of ``group_events``.
            file_path: JSON-lines file with ``sid`` and ``embed`` columns.

        Returns:
            ``(user_vecs, label_vecs)`` in the same form as ``embed_events``.

        Raises:
            ValueError: If the cache has no embedding for some session.
            IndexError: If a label's ``did`` is not present in ``self.items``.
        """
        user_embeds = pd.read_json(file_path, lines=True)
        user_with_embeds = pd.merge(df, user_embeds, on="sid", how="left")
        if user_with_embeds["embed"].isna().any():
            raise ValueError(
                f"{file_path} has no embedding for some sessions; "
                "delete the file to recompute it"
            )
        positions = self._item_positions()
        user_vecs = jnp.array([jnp.array(x) for x in user_with_embeds["embed"]])
        label_vecs = jnp.array(
            [self._label_index(positions, x["did"]) for x in user_with_embeds["label"]]
        )
        return user_vecs, label_vecs

    def _user_embeddings(self, sessions, cache_file, split_name):
        """Return ``(user_vecs, label_vecs)`` from ``cache_file`` if it exists, else compute them."""
        if path.exists(cache_file):
            logger.info(f"Found a saved {split_name} embedding file...")
            return self.load_user_embeddings(sessions, cache_file)
        return self.embed_events(sessions)

    def setup(self):
        """Build item and user embeddings and the retrieval evaluator."""
        corpus = {
            f"{did}": text
            for did, text in zip(self.items["did"], self.items["text"])
        }
        if path.exists(self.item_embeddings_file):
            logger.info("Found a saved item embedding file...")
            self.item_embeds = self.load_item_embeddings()
        else:
            # Accumulate into a fresh list so calling setup() again does not
            # try to extend an already-converted array.
            embeds = []
            for start in trange(
                0, len(self.items), _ITEM_BATCH_SIZE, desc="Embedding items"
            ):
                batch = self.items.iloc[start : start + _ITEM_BATCH_SIZE]
                embeds.extend(self.embed_items(batch))
            self.item_embeds = embeds
        self.item_embeds = jnp.array(self.item_embeds)

        self.dataloader = self.group_events(self.dataloader)
        self.test_dataset = self.group_events(self.test_dataset)
        self.dataset_len = len(self.dataloader)

        self.train_user_embeds, self.train_label_embeds = self._user_embeddings(
            self.dataloader, self.train_user_embeddings_file, "train"
        )
        self.test_user_embeds, self.test_label_embeds = self._user_embeddings(
            self.test_dataset, self.test_user_embeddings_file, "test"
        )

        test_sids = self.test_dataset["sid"].tolist()
        test_labels = self.test_dataset["label"].tolist()
        users = {f"{sid}": sid for sid in test_sids}
        relevant_docs = {
            f"{sid}": [f"{label['did']}"]
            for sid, label in zip(test_sids, test_labels)
        }
        self.evaluator = RetrivalEvaluator(
            queries=users,
            corpus=corpus,
            relevant_docs=relevant_docs,
            corpus_chunk_size=40000,
            batch_size=512,
            show_progress_bar=True,
        )

    def get_initial_params(self):
        """Initialise the model parameters with a dummy one-row batch.

        The dummy width is the width of the item embeddings once they exist
        (it differs from the encoder's ``hidden_size`` when extra features are
        appended); before that it falls back to ``hidden_size``.
        """
        input_dim = self.text_encoder.config.hidden_size
        item_embeds = getattr(self, "item_embeds", None)
        if item_embeds is not None and len(item_embeds) > 0:
            input_dim = jnp.asarray(item_embeds[0]).shape[-1]
        batch = jnp.zeros((1, input_dim))
        params = self.model.init(jax.random.PRNGKey(0), batch, batch, training=False)
        return params["params"]

    def _embed_test_set(self, params):
        """Run the model on the test users and all items in inference mode.

        Inputs and outputs are L2-normalised exactly as the jitted
        ``train_step`` does, so evaluation sees the same input scale as
        training.
        """
        q, d = self.model.apply(
            {"params": params},
            _l2_normalize(jnp.asarray(self.test_user_embeds)),
            _l2_normalize(jnp.asarray(self.item_embeds)),
            training=False,
        )
        return _l2_normalize(q), _l2_normalize(d)

    def train_step(self, _batch, start, end):
        """Run one optimisation step and record a validation loss.

        Args:
            _batch: Unused; the batch is drawn from the precomputed embeddings.
            start: Start offset into a fresh random permutation of the
                training sessions.
            end: End offset (exclusive) into that permutation.

        Returns:
            ``(state, loss)`` as returned by the jitted ``train_step``.
        """
        # Separate keys for shuffling and dropout: a JAX key must not be
        # consumed twice.
        self.rng, perm_rng, dropout_rng = jax.random.split(self.rng, 3)
        user_embeds = jnp.asarray(self.train_user_embeds)
        label_ids = jnp.asarray(self.train_label_embeds)
        # One index permutation applied to both arrays keeps users and labels
        # aligned without shuffling the full embedding matrix.
        picked = jax.random.permutation(perm_rng, user_embeds.shape[0])[start:end]
        state, l = train_step(
            self.state,
            user_embeds[picked],
            jnp.asarray(self.item_embeds),
            label_ids[picked],
            dropout_rng,
        )
        # Validation uses the parameters currently held in self.state, i.e.
        # the ones from before this step's update.
        q, d = self._embed_test_set(self.state.params)
        val_l = loss.sparse_categorical_cross_entropy(q, d, self.test_label_embeds)
        self.val_loss.append(val_l)
        return state, l

    def eval_step(self):
        """Embed the test users and all items, then run the recall evaluation."""
        q, d = self._embed_test_set(self.state.params)
        self.evaluator(query_embeddings=q, corpus_embeddings=d, metrics=["recall"])


@jax.jit
def train_step(state: TrainState, user_embeds, item_embeds, labels, rng):
    """Apply one gradient update.

    Args:
        state: Current train state (parameters, apply function, optimiser).
        user_embeds: Batch of user vectors.
        item_embeds: Vectors of all items.
        labels: For each user, the index of its target item in ``item_embeds``.
        rng: PRNG key used for dropout.

    Returns:
        ``(new_state, loss)``.
    """

    def loss_fn(params):
        u, i = state.apply_fn(
            {"params": params},
            _l2_normalize(user_embeds),
            _l2_normalize(item_embeds),
            training=True,
            rngs={"dropout": rng},
        )
        return loss.sparse_categorical_cross_entropy(
            _l2_normalize(u), _l2_normalize(i), labels
        )

    l, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, l