{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "Gender Over Time", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "mGDHOsFEIvKY" }, "source": [ "# [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank)\n", "\n", "Select `Runtime -> Run all` to generate the plots in the \"Appendix: Differences Over Time\" section. \n", "\n", "In addition to the difference between sentence 0 and sentence 1, the logits of the top tokens over time for sentence 0 and sentence 1 are also shown here. " ] }, { "cell_type": "markdown", "metadata": { "id": "ULz91t5Mfsfh" }, "source": [ "# Helpers" ] }, { "cell_type": "code", "metadata": { "id": "OQvEH3U6Q_OE" }, "source": [ "%%capture\n", "\n", "import os\n", "import torch\n", "!pip install transformers\n", "from transformers import (BertForMaskedLM, BertTokenizer)\n", "import numpy as np\n", "import pandas as pd\n", "import IPython\n", "from google.colab import output" ], "execution_count": 1, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "9bKnXE1DRAvx" }, "source": [ "%%capture\n", "\n", "modelpath_bert_large = \"bert-large-uncased\"\n", "tokenizer = BertTokenizer.from_pretrained(modelpath_bert_large)\n", "model = BertForMaskedLM.from_pretrained(modelpath_bert_large)\n", "model.eval()\n", "\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "model = model.to(device)" ], "execution_count": 2, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "yoRggB_YgVgB" }, "source": [ "" ], "execution_count": 2, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "3YsB2WUJfu3i" }, "source": [ "def calcYearEmbeds(sentence):\n", " sentenceTokens = []\n", " for year in range(minYear, maxYear):\n", " 
sentenceTokens.append(tokenizer.encode(sentence.replace('YEAR', str(year))))\n", "\n", " inputs = torch.tensor(sentenceTokens).to(device)\n", " outputs = model(inputs)\n", " embeds = outputs[0].cpu().detach().numpy()\n", "\n", " index_of_mask = sentenceTokens[0].index(103)\n", " return np.take(embeds, index_of_mask, axis=1)" ], "execution_count": 3, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lHxngLxBJgSK" }, "source": [ "def calcTopTokens(e0, e1):\n", " # Merge e0 and e1 into a df; \n", " df = pd.DataFrame({'e0': e0.flatten(), 'e1': e1.flatten()})\n", " df['dif'] = df['e0'] - df['e1']\n", "\n", " # Calculate year and token_index based on index \n", " df.reset_index(inplace=True)\n", " df['token_index'] = df['index'].mod(30522)\n", " df['year_index'] = df['index'].div(30522).apply(np.floor)\n", "\n", " # Group by token_index \n", " # Sentences rank tokens separately so the less likely sentence will still include its outliers\n", " by_token = df.groupby('token_index')[['e0', 'e1']].mean()\n", " by_token['i0'] = by_token['e0'].rank(ascending=False)\n", " by_token['i1'] = by_token['e1'].rank(ascending=False)\n", " by_token['i_combined_min'] = by_token[['i0','i1']].min(axis=1).rank()\n", " \n", " top_tokens = by_token.loc[by_token['i_combined_min'] < 150]\n", "\n", " return df.loc[df['token_index'].isin(top_tokens.index)]\n" ], "execution_count": 4, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "81VpA88LmuuV" }, "source": [ "" ], "execution_count": 4, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Yb3jJxcyfdwE" }, "source": [ "HTML_DEV_TEMPLATE = '''\n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", "'''\n", "\n", "HTML_TEMPLATE = '''\n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", "'''" ], "execution_count": 11, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "VLQLT18HtaYU" }, "source": [ " " ], "execution_count": 5, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "hMqMXXR1fgr3" }, "source": [ " # Edit s0 and s1 to see other differences over time\n" ] }, { "cell_type": "code", "metadata": { "id": "n3i4tXy7eV9z", "colab": { "base_uri": "https://localhost:8080/", "height": 405 }, "outputId": "b39f8887-7921-49f8-81b2-f32fe11b49dd" }, "source": [ "s0 = 'In YEAR, he was arrested for [MASK].'\n", "s1 = 'In YEAR, she was arrested for [MASK].'\n", "\n", "minYear = 1860 # min 1707, \"1706\" token not in BERT vocab.\n", "maxYear = 2018 # max 2022, BERT was trained in 2018.\n", "\n", "e0 = calcYearEmbeds(s0)\n", "e1 = calcYearEmbeds(s1)\n", "\n", "out_df = calcTopTokens(e0, e1)\n", "tidyCSV = out_df[['e0', 'e1', 'token_index', 'year_index']].to_csv(index=False)\n", "js_data = {'minYear': minYear, 'maxYear': maxYear, 's0': s0, 's1': s1, 'tidyCSV': tidyCSV}\n", "IPython.display.display(IPython.display.HTML(HTML_TEMPLATE.format(js_data=js_data)))" ], "execution_count": 12, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", "\n", " \n", " \n", " \n" ], "text/plain": [ "