{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "os.environ['TORCH_LOGS'] = '+dynamic'\n",
    "import pylab as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n",
      "  WeightNorm.apply(module, name, dim)\n",
      "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:123: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "CustomAlbert.forward() got an unexpected keyword argument 'attention_mask'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[3], line 10\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mkokoro\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m generate\n\u001b[1;32m      9\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHow could I know? It\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124ms an unanswerable question. Like asking an unborn child if they\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mll lead a good life. They haven\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt even been born.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 10\u001b[0m audio, out_ps \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvoicepack\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     12\u001b[0m \u001b[38;5;66;03m# 4️⃣ Display the 24khz audio and print the output phonemes\u001b[39;00m\n\u001b[1;32m     13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m display, Audio\n",
      "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/kokoro.py:147\u001b[0m, in \u001b[0;36mgenerate\u001b[0;34m(model, text, voicepack, lang, speed)\u001b[0m\n\u001b[1;32m    145\u001b[0m     \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mTruncated to 510 tokens\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    146\u001b[0m ref_s \u001b[38;5;241m=\u001b[39m voicepack[\u001b[38;5;28mlen\u001b[39m(tokens)]\n\u001b[0;32m--> 147\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mref_s\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mspeed\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    148\u001b[0m ps \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mnext\u001b[39m(k \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m VOCAB\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;241m==\u001b[39m v) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m tokens)\n\u001b[1;32m    149\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out, ps\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    115\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/kokoro.py:119\u001b[0m, in \u001b[0;36mforward\u001b[0;34m(model, tokens, ref_s, speed)\u001b[0m\n\u001b[1;32m    117\u001b[0m input_lengths \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mLongTensor([tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]])\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m    118\u001b[0m text_mask \u001b[38;5;241m=\u001b[39m length_to_mask(input_lengths)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m--> 119\u001b[0m bert_dur \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbert\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m~\u001b[39;49m\u001b[43mtext_mask\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mint\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    120\u001b[0m d_en \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mbert_encoder(bert_dur)\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m    121\u001b[0m s \u001b[38;5;241m=\u001b[39m ref_s[:, \u001b[38;5;241m128\u001b[39m:]\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[0;31mTypeError\u001b[0m: CustomAlbert.forward() got an unexpected keyword argument 'attention_mask'"
     ]
    }
   ],
   "source": [
    "from models import build_model\n",
    "import torch\n",
    "device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = build_model('kokoro-v0_19.pth', device)\n",
    "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)\n",
    "\n",
    "# 3️⃣ Call generate, which returns a 24khz audio waveform and a string of output phonemes\n",
    "from kokoro import generate\n",
    "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n",
    "audio, out_ps = generate(model, text, voicepack)\n",
    "\n",
    "# 4️⃣ Display the 24khz audio and print the output phonemes\n",
    "from IPython.display import display, Audio\n",
    "display(Audio(data=audio, rate=24000, autoplay=True))\n",
    "print(out_ps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 640, 348])\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hˌaʊ kʊd aɪ nˈoʊ? ɪts ɐn ʌnˈænsɚɹəbəl kwˈɛstʃən. lˈaɪk ˈæskɪŋ ɐn ʌnbˈɔːɹn tʃˈaɪld ɪf ðeɪl lˈiːd ɐ ɡˈʊd lˈaɪf. ðeɪ hˈævənt ˈiːvən bˌɪn bˈɔːɹn.\n"
     ]
    }
   ],
   "source": [
    "from kokoro import phonemize, tokenize, length_to_mask\n",
    "import torch.nn.functional as F\n",
    "model = model\n",
    "speed = 1.\n",
    "\n",
    "ps = phonemize(text, \"a\")\n",
    "tokens = tokenize(ps)\n",
    "\n",
    "tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)\n",
    "\n",
    "# tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n",
    "input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
    "\n",
    "text_mask = length_to_mask(input_lengths).to(device)\n",
    "bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
    "\n",
    "\n",
    "d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
    "\n",
    "ref_s =voicepack[tokens.shape[1]]\n",
    "s = ref_s[:, 128:]\n",
    "\n",
    "d = model.predictor.text_encoder.inference(d_en, s)\n",
    "x, _ = model.predictor.lstm(d)\n",
    "\n",
    "duration = model.predictor.duration_proj(x)\n",
    "duration = torch.sigmoid(duration).sum(axis=-1) / speed\n",
    "pred_dur = torch.round(duration).clamp(min=1).long()\n",
    "max_mels = pred_dur.sum().item()\n",
    "pred_aln_trg = torch.zeros(input_lengths, max_mels)\n",
    "\n",
    "c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n",
    "c_end = c_start + pred_dur[0,:]\n",
    "\n",
    "for row, cs, ce in zip(pred_aln_trg, c_start, c_end):\n",
    "    row[cs:ce] = 1\n",
    "    \n",
    "en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)\n",
    "print(en.shape)\n",
    "F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
    "t_en = model.text_encoder.inference(tokens)\n",
    "asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)\n",
    "output = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().detach().cpu().numpy()\n",
    "\n",
    "from IPython.display import display, Audio\n",
    "display(Audio(data=output, rate=24000, autoplay=True))\n",
    "print(out_ps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "ename": "Error",
     "evalue": "Unable to infer type of dictionary: Cannot infer concrete type of torch.nn.Module",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mError\u001b[0m                                     Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m scrpt \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscript\u001b[49m\u001b[43m(\u001b[49m\u001b[43mMODEL\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1429\u001b[0m, in \u001b[0;36mscript\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m   1427\u001b[0m prev \u001b[38;5;241m=\u001b[39m _TOPLEVEL\n\u001b[1;32m   1428\u001b[0m _TOPLEVEL \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 1429\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43m_script_impl\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1430\u001b[0m \u001b[43m    \u001b[49m\u001b[43mobj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1431\u001b[0m \u001b[43m    \u001b[49m\u001b[43moptimize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1432\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_frames_up\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_frames_up\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1433\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_rcb\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_rcb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1434\u001b[0m \u001b[43m    \u001b[49m\u001b[43mexample_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexample_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1435\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1437\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m prev:\n\u001b[1;32m   1438\u001b[0m     log_torchscript_usage(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscript\u001b[39m\u001b[38;5;124m\"\u001b[39m, model_id\u001b[38;5;241m=\u001b[39m_get_model_id(ret))\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1154\u001b[0m, in \u001b[0;36m_script_impl\u001b[0;34m(obj, optimize, _frames_up, _rcb, example_inputs)\u001b[0m\n\u001b[1;32m   1151\u001b[0m     obj \u001b[38;5;241m=\u001b[39m obj\u001b[38;5;241m.\u001b[39m__prepare_scriptable__() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(obj, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__prepare_scriptable__\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m obj  \u001b[38;5;66;03m# type: ignore[operator]\u001b[39;00m\n\u001b[1;32m   1153\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, \u001b[38;5;28mdict\u001b[39m):\n\u001b[0;32m-> 1154\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m create_script_dict(obj)\n\u001b[1;32m   1155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m   1156\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m create_script_list(obj)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/jit/_script.py:1066\u001b[0m, in \u001b[0;36mcreate_script_dict\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m   1053\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_script_dict\u001b[39m(obj):\n\u001b[1;32m   1054\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   1055\u001b[0m \u001b[38;5;124;03m    Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.\u001b[39;00m\n\u001b[1;32m   1056\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1064\u001b[0m \u001b[38;5;124;03m        zero copy overhead.\u001b[39;00m\n\u001b[1;32m   1065\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1066\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mScriptDict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mError\u001b[0m: Unable to infer type of dictionary: Cannot infer concrete type of torch.nn.Module"
     ]
    }
   ],
   "source": [
    "scrpt = torch.jit.script(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 143, 640])\n",
      "torch.Size([1, 640, 143])\n",
      "torch.Size([1, 512, 143])\n",
      "torch.Size([1, 143, 348])\n",
      "en.shape=torch.Size([1, 640, 348])\n",
      "s.shape=torch.Size([1, 128])\n",
      "en.dtype=torch.float32\n",
      "s.dtype=torch.float32\n",
      "torch.Size([1, 512, 143])\n",
      "torch.Size([1, 512, 348])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fbbc0db3220>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(d.shape)\n",
    "print(d.transpose(-1,-2).shape)\n",
    "print(d_en.shape)\n",
    "print(pred_aln_trg.unsqueeze(0).shape)\n",
    "print(f\"{en.shape=}\")\n",
    "print(f\"{s.shape=}\")\n",
    "print(f\"{en.dtype=}\")\n",
    "print(f\"{s.dtype=}\")\n",
    "\n",
    "print(t_en.shape)\n",
    "print(asr.shape)\n",
    "pl.imshow(pred_aln_trg[:,:])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Export StyleTTS2 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0101 18:56:11.998000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0].size()[1] [2, 510] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n",
      "I0101 18:56:12.007000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s0 = 143 (range_refined_to_singleton) VR[143, 143]\n",
      "I0101 18:56:12.008000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s0, 143) [guard added] (mp/ipykernel_2488298/2554868606.py:17 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s0, 143)\"\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0101 18:56:27.383000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3317] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:390 in local_scalar_dense)\n",
      "I0101 18:56:27.387000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4957 in arange), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"u0 >= 0\"\n",
      "W0101 18:56:28.575000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Traceback (most recent call last):\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]   File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py\", line 262, in wrapper\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]     return retlog(fn(*args, **kwargs))\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]   File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5122, in evaluate_expr\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]     return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]   File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5238, in _evaluate_expr\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]     raise self._make_data_dependent_error(\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: u0)\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Potential framework code culprit (scroll up for full backtrace):\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]   File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298]     return self._op_dk(dk, *args, **kwargs)\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more information, run with TORCH_LOGS=\"dynamic\"\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n",
      "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n"
     ]
    },
    {
     "ename": "GuardOnDataDependentSymNode",
     "evalue": "Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n  File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n    return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n  File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n    x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n  1. torch._check(x.shape[2])\n  2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[39], line 61\u001b[0m\n\u001b[1;32m     58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:batch, \u001b[38;5;241m1\u001b[39m:token_len}}\n\u001b[1;32m     60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstyle_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m     62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m    264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m    265\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    266\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    267\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    268\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    269\u001b[0m     )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m    \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    275\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    276\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    277\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1010\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1011\u001b[0m         log_export_usage(\n\u001b[1;32m   1012\u001b[0m             event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m   1013\u001b[0m             \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m   1014\u001b[0m             message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m   1015\u001b[0m             flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m   1016\u001b[0m         )\n\u001b[0;32m-> 1017\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m   1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m   1019\u001b[0m     _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    989\u001b[0m     start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m     ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    991\u001b[0m     end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m    992\u001b[0m     log_export_usage(\n\u001b[1;32m    993\u001b[0m         event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    994\u001b[0m         metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m    995\u001b[0m         flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m    996\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m    997\u001b[0m     )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/exported_program.py:114\u001b[0m, in \u001b[0;36m_disable_prexisiting_fake_mode.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    111\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m    112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    113\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m unset_fake_temporarily():\n\u001b[0;32m--> 114\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1880\u001b[0m, in \u001b[0;36m_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m   1877\u001b[0m \u001b[38;5;66;03m# Call the appropriate export function based on the strictness of tracing.\u001b[39;00m\n\u001b[1;32m   1878\u001b[0m export_func \u001b[38;5;241m=\u001b[39m _strict_export \u001b[38;5;28;01mif\u001b[39;00m strict \u001b[38;5;28;01melse\u001b[39;00m _non_strict_export\n\u001b[0;32m-> 1880\u001b[0m export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43mexport_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m   1881\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1882\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1883\u001b[0m \u001b[43m    \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1884\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1885\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1886\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1887\u001b[0m \u001b[43m    \u001b[49m\u001b[43moriginal_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1888\u001b[0m \u001b[43m    \u001b[49m\u001b[43moriginal_in_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1889\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_complex_guards_as_runtime_asserts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1890\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1891\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1892\u001b[0m export_graph_signature: ExportGraphSignature \u001b[38;5;241m=\u001b[39m export_artifact\u001b[38;5;241m.\u001b[39maten\u001b[38;5;241m.\u001b[39msig\n\u001b[1;32m   1894\u001b[0m forward_arg_names \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   1895\u001b[0m     _get_forward_arg_names(mod, args, kwargs) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_torch_jit_trace \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1896\u001b[0m )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1683\u001b[0m, in \u001b[0;36m_non_strict_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, dispatch_tracing_mode)\u001b[0m\n\u001b[1;32m   1667\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) \u001b[38;5;28;01mas\u001b[39;00m (\n\u001b[1;32m   1668\u001b[0m     patched_mod,\n\u001b[1;32m   1669\u001b[0m     new_fake_args,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1672\u001b[0m     map_fake_to_real,\n\u001b[1;32m   1673\u001b[0m ):\n\u001b[1;32m   1674\u001b[0m     _to_aten_func \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   1675\u001b[0m         _export_to_aten_ir_make_fx\n\u001b[1;32m   1676\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m dispatch_tracing_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmake_fx\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1681\u001b[0m         )\n\u001b[1;32m   1682\u001b[0m     )\n\u001b[0;32m-> 1683\u001b[0m     aten_export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43m_to_aten_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m   1684\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpatched_mod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1685\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnew_fake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1686\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnew_fake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1687\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfake_params_buffers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1688\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnew_fake_constant_attrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1689\u001b[0m \u001b[43m        \u001b[49m\u001b[43mproduce_guards_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_produce_guards_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1690\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_tuplify_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1691\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1692\u001b[0m     \u001b[38;5;66;03m# aten_export_artifact.constants contains only fake script objects, we need to map them back\u001b[39;00m\n\u001b[1;32m   1693\u001b[0m     aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m   1694\u001b[0m         fqn: map_fake_to_real[obj] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, FakeScriptObject) \u001b[38;5;28;01melse\u001b[39;00m obj\n\u001b[1;32m   1695\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m fqn, obj \u001b[38;5;129;01min\u001b[39;00m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m   1696\u001b[0m     }\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:637\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m    627\u001b[0m \u001b[38;5;66;03m# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,\u001b[39;00m\n\u001b[1;32m    628\u001b[0m \u001b[38;5;66;03m# otherwise aot_export_module will error out because it sees a mix of fake_modes.\u001b[39;00m\n\u001b[1;32m    629\u001b[0m \u001b[38;5;66;03m# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.\u001b[39;00m\n\u001b[1;32m    630\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mstateless\u001b[38;5;241m.\u001b[39m_reparametrize_module(\n\u001b[1;32m    631\u001b[0m     mod,\n\u001b[1;32m    632\u001b[0m     fake_params_buffers,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    635\u001b[0m     stack_weights\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m    636\u001b[0m ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context():  \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[0;32m--> 637\u001b[0m     gm, graph_signature \u001b[38;5;241m=\u001b[39m \u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43maot_export_module\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    638\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    639\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    640\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrace_joint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    641\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    642\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdecompositions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecomp_table\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    643\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    644\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    646\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_maybe_fixup_gm_and_output_node_meta\u001b[39m(old_gm, new_gm):\n\u001b[1;32m    647\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(old_gm, torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mGraphModule):\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1611\u001b[0m, in \u001b[0;36m_non_strict_export.<locals>._tuplify_outputs.<locals>._aot_export_non_strict\u001b[0;34m(mod, args, kwargs, **flags)\u001b[0m\n\u001b[1;32m   1605\u001b[0m new_preserved_call_signatures \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m   1606\u001b[0m     \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_export_root.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m i \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m preserve_module_call_signature\n\u001b[1;32m   1607\u001b[0m ]\n\u001b[1;32m   1608\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _wrap_submodules(\n\u001b[1;32m   1609\u001b[0m     wrapped_mod, new_preserved_call_signatures, module_call_specs\n\u001b[1;32m   1610\u001b[0m ):\n\u001b[0;32m-> 1611\u001b[0m     gm, sig \u001b[38;5;241m=\u001b[39m \u001b[43maot_export\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwrapped_mod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflags\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1612\u001b[0m     log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExported program from AOTAutograd:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, gm)\n\u001b[1;32m   1614\u001b[0m sig\u001b[38;5;241m.\u001b[39mparameters \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_map(_strip_root, sig\u001b[38;5;241m.\u001b[39mparameters)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1246\u001b[0m, in \u001b[0;36maot_export_module\u001b[0;34m(mod, args, decompositions, trace_joint, output_loss_index, pre_dispatch, dynamic_shapes, kwargs)\u001b[0m\n\u001b[1;32m   1243\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(args)\n\u001b[1;32m   1245\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx():\n\u001b[0;32m-> 1246\u001b[0m     fx_g, metadata, in_spec, out_spec \u001b[38;5;241m=\u001b[39m \u001b[43m_aot_export_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1247\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfn_to_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1248\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfull_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1249\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdecompositions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecompositions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1250\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnum_params_buffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams_len\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1251\u001b[0m \u001b[43m        \u001b[49m\u001b[43mno_tangents\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   1252\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1253\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1254\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1255\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1256\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trace_joint:\n\u001b[1;32m   1258\u001b[0m     \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mflattened_joint\u001b[39m(\u001b[38;5;241m*\u001b[39margs):\n\u001b[1;32m   1259\u001b[0m         \u001b[38;5;66;03m# The idea here is that the joint graph that AOTAutograd creates has some strict properties:\u001b[39;00m\n\u001b[1;32m   1260\u001b[0m         \u001b[38;5;66;03m# (1) It accepts two arguments (primals, tangents), and pytree_flattens them\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1273\u001b[0m         \u001b[38;5;66;03m# This function \"fixes\" both of the above by removing any tangent inputs,\u001b[39;00m\n\u001b[1;32m   1274\u001b[0m         \u001b[38;5;66;03m# and removing pytrees from the original FX graph.\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1480\u001b[0m, in \u001b[0;36m_aot_export_function\u001b[0;34m(func, args, num_params_buffers, decompositions, no_tangents, pre_dispatch, dynamic_shapes, kwargs)\u001b[0m\n\u001b[1;32m   1477\u001b[0m fake_mode, shape_env \u001b[38;5;241m=\u001b[39m construct_fake_mode(flat_args, aot_config)\n\u001b[1;32m   1478\u001b[0m fake_flat_args \u001b[38;5;241m=\u001b[39m process_inputs(flat_args, aot_config, fake_mode, shape_env)\n\u001b[0;32m-> 1480\u001b[0m fx_g, meta \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_aot_dispatcher_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1481\u001b[0m \u001b[43m    \u001b[49m\u001b[43mflat_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1482\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfake_flat_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1483\u001b[0m \u001b[43m    \u001b[49m\u001b[43maot_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1484\u001b[0m \u001b[43m    \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1485\u001b[0m \u001b[43m    \u001b[49m\u001b[43mshape_env\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1486\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1487\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fx_g, meta, in_spec, out_spec\u001b[38;5;241m.\u001b[39mspec\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:522\u001b[0m, in \u001b[0;36mcreate_aot_dispatcher_function\u001b[0;34m(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)\u001b[0m\n\u001b[1;32m    514\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_aot_dispatcher_function\u001b[39m(\n\u001b[1;32m    515\u001b[0m     flat_fn,\n\u001b[1;32m    516\u001b[0m     fake_flat_args: FakifiedFlatArgs,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    519\u001b[0m     shape_env: Optional[ShapeEnv],\n\u001b[1;32m    520\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Callable, ViewAndMutationMeta]:\n\u001b[1;32m    521\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m dynamo_timed(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcreate_aot_dispatcher_function\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 522\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_create_aot_dispatcher_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    523\u001b[0m \u001b[43m            \u001b[49m\u001b[43mflat_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfake_flat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maot_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape_env\u001b[49m\n\u001b[1;32m    524\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:623\u001b[0m, in \u001b[0;36m_create_aot_dispatcher_function\u001b[0;34m(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)\u001b[0m\n\u001b[1;32m    621\u001b[0m     ctx \u001b[38;5;241m=\u001b[39m nullcontext()\n\u001b[1;32m    622\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx:\n\u001b[0;32m--> 623\u001b[0m     fw_metadata \u001b[38;5;241m=\u001b[39m \u001b[43mrun_functionalized_fw_and_collect_metadata\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    624\u001b[0m \u001b[43m        \u001b[49m\u001b[43mflat_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    625\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstatic_input_indices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maot_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstatic_input_indices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    626\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkeep_input_mutations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maot_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkeep_inference_input_mutations\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    627\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_train\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mneeds_autograd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    628\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maot_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    629\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m_dup_fake_script_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfake_flat_args\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    631\u001b[0m req_subclass_dispatch \u001b[38;5;241m=\u001b[39m requires_subclass_dispatch(\n\u001b[1;32m    632\u001b[0m     fake_flat_args, fw_metadata\n\u001b[1;32m    633\u001b[0m )\n\u001b[1;32m    635\u001b[0m output_and_mutation_safe \u001b[38;5;241m=\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28many\u001b[39m(\n\u001b[1;32m    636\u001b[0m     x\u001b[38;5;241m.\u001b[39mrequires_grad\n\u001b[1;32m    637\u001b[0m     \u001b[38;5;66;03m# view-type operations preserve requires_grad even in no_grad.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    652\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m fw_metadata\u001b[38;5;241m.\u001b[39minput_info\n\u001b[1;32m    653\u001b[0m )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:173\u001b[0m, in \u001b[0;36mrun_functionalized_fw_and_collect_metadata.<locals>.inner\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m    170\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m disable_above, mode, suppress_pending:\n\u001b[1;32m    171\u001b[0m     \u001b[38;5;66;03m# precondition: The passed in function already handles unflattening inputs + flattening outputs\u001b[39;00m\n\u001b[1;32m    172\u001b[0m     flat_f_args \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_map(_to_fun, flat_args)\n\u001b[0;32m--> 173\u001b[0m     flat_f_outs \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflat_f_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    174\u001b[0m     \u001b[38;5;66;03m# We didn't do any tracing, so we don't need to process the\u001b[39;00m\n\u001b[1;32m    175\u001b[0m     \u001b[38;5;66;03m# unbacked symbols, they will just disappear into the ether.\u001b[39;00m\n\u001b[1;32m    176\u001b[0m     \u001b[38;5;66;03m# Also, prevent memoization from applying.\u001b[39;00m\n\u001b[1;32m    177\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m fake_mode:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:182\u001b[0m, in \u001b[0;36mcreate_tree_flattened_fn.<locals>.flat_fn\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m    180\u001b[0m \u001b[38;5;28;01mnonlocal\u001b[39;00m out_spec\n\u001b[1;32m    181\u001b[0m args, kwargs \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_unflatten(flat_args, tensor_args_spec)\n\u001b[0;32m--> 182\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    183\u001b[0m flat_out, spec \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_flatten(tree_out)\n\u001b[1;32m    184\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m flat_out:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:863\u001b[0m, in \u001b[0;36mcreate_functional_call.<locals>.functional_call\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    859\u001b[0m                 out \u001b[38;5;241m=\u001b[39m PropagateUnbackedSymInts(mod)\u001b[38;5;241m.\u001b[39mrun(\n\u001b[1;32m    860\u001b[0m                     \u001b[38;5;241m*\u001b[39margs[params_len:], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m    861\u001b[0m                 )\n\u001b[1;32m    862\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 863\u001b[0m         out \u001b[38;5;241m=\u001b[39m \u001b[43mmod\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[43mparams_len\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    865\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[1;32m    866\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    867\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGraph output must be a (). This is so that we can avoid \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    868\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpytree processing of the outputs. Please change the module to \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    869\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhave tuple outputs or use aot_module instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    870\u001b[0m     )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1598\u001b[0m, in \u001b[0;36m_non_strict_export.<locals>._tuplify_outputs.<locals>._aot_export_non_strict.<locals>.Wrapper.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1594\u001b[0m         tree_out \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mInterpreter(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_export_root)\u001b[38;5;241m.\u001b[39mrun(\n\u001b[1;32m   1595\u001b[0m             \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m   1596\u001b[0m         )\n\u001b[1;32m   1597\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1598\u001b[0m     tree_out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_export_root\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1599\u001b[0m flat_outs, out_spec \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_flatten(tree_out)\n\u001b[1;32m   1600\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(flat_outs)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "Cell \u001b[0;32mIn[39], line 46\u001b[0m, in \u001b[0;36mStyleTTS2.forward\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m     42\u001b[0m pred_aln_trg\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mvstack(pred_aln_trg_list)\n\u001b[1;32m     44\u001b[0m en \u001b[38;5;241m=\u001b[39m d\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 46\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpredictor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     47\u001b[0m t_en \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext_encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minference(tokens)\n\u001b[1;32m     48\u001b[0m asr \u001b[38;5;241m=\u001b[39m t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n",
      "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/models.py:471\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m    466\u001b[0m x1 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m    467\u001b[0m \u001b[38;5;66;03m# torch._check(x1.dim() == 3, lambda: print(f\"Expected 3D tensor, got {x1.dim()}D tensor\"))\u001b[39;00m\n\u001b[1;32m    468\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[1] > 0, lambda: print(f\"Shape 2, got {x1.shape[1]}\"))\u001b[39;00m\n\u001b[1;32m    469\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[2] > 0, lambda: print(f\"Shape 2, got {x1.shape[2]}\"))\u001b[39;00m\n\u001b[1;32m    470\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.shape[2]}\"))\u001b[39;00m\n\u001b[0;32m--> 471\u001b[0m x2, _temp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    472\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\u001b[39;00m\n\u001b[1;32m    474\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x2\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m   1120\u001b[0m         hx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpermute_hidden(hx, sorted_indices)\n\u001b[1;32m   1122\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1123\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1124\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1125\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1126\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1127\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1128\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1129\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1130\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1131\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1132\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1133\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1134\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1135\u001b[0m     result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\n\u001b[1;32m   1136\u001b[0m         \u001b[38;5;28minput\u001b[39m,\n\u001b[1;32m   1137\u001b[0m         batch_sizes,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1144\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional,\n\u001b[1;32m   1145\u001b[0m     )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:520\u001b[0m, in \u001b[0;36m_NonStrictTorchFunctionHandler.__torch_function__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m    512\u001b[0m         log\u001b[38;5;241m.\u001b[39mdebug(\n\u001b[1;32m    513\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m called at \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m in \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    514\u001b[0m             func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    517\u001b[0m             frame\u001b[38;5;241m.\u001b[39mf_code\u001b[38;5;241m.\u001b[39mco_name,\n\u001b[1;32m    518\u001b[0m         )\n\u001b[1;32m    519\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 520\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    521\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m GuardOnDataDependentSymNode \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    522\u001b[0m     _suggest_fixes_for_data_dependent_error_non_strict(e)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_decomp/decompositions.py:3476\u001b[0m, in \u001b[0;36mlstm_impl\u001b[0;34m(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)\u001b[0m\n\u001b[1;32m   3474\u001b[0m hidden \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(hx[\u001b[38;5;241m0\u001b[39m], hx[\u001b[38;5;241m1\u001b[39m]))\n\u001b[1;32m   3475\u001b[0m layer_fn \u001b[38;5;241m=\u001b[39m select_one_layer_lstm_function(\u001b[38;5;28minput\u001b[39m, hx, params)\n\u001b[0;32m-> 3476\u001b[0m out, final_hiddens \u001b[38;5;241m=\u001b[39m \u001b[43m_rnn_helper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3477\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3478\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3479\u001b[0m \u001b[43m    \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3480\u001b[0m \u001b[43m    \u001b[49m\u001b[43mhas_biases\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3481\u001b[0m \u001b[43m    \u001b[49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3482\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3483\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3484\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3485\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3486\u001b[0m \u001b[43m    \u001b[49m\u001b[43mlayer_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3487\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3488\u001b[0m final_hiddens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39mfinal_hiddens))\n\u001b[1;32m   3489\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out, torch\u001b[38;5;241m.\u001b[39mstack(final_hiddens[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m0\u001b[39m), torch\u001b[38;5;241m.\u001b[39mstack(final_hiddens[\u001b[38;5;241m1\u001b[39m], \u001b[38;5;241m0\u001b[39m)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_decomp/decompositions.py:3151\u001b[0m, in \u001b[0;36m_rnn_helper\u001b[0;34m(input, hidden, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, layer_fn)\u001b[0m\n\u001b[1;32m   3147\u001b[0m cur_params, cur_hidden, bidir_params, bidir_hidden \u001b[38;5;241m=\u001b[39m params_hiddens(\n\u001b[1;32m   3148\u001b[0m     params, hidden, i, bidirectional\n\u001b[1;32m   3149\u001b[0m )\n\u001b[1;32m   3150\u001b[0m dropout \u001b[38;5;241m=\u001b[39m dropout \u001b[38;5;28;01mif\u001b[39;00m (train \u001b[38;5;129;01mand\u001b[39;00m num_layers \u001b[38;5;241m<\u001b[39m i \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;241m0.0\u001b[39m\n\u001b[0;32m-> 3151\u001b[0m fwd_inp, fwd_hidden \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcur_hidden\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcur_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhas_biases\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3152\u001b[0m final_hiddens\u001b[38;5;241m.\u001b[39mappend(fwd_hidden)\n\u001b[1;32m   3154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bidirectional:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_decomp/decompositions.py:3333\u001b[0m, in \u001b[0;36mone_layer_lstm\u001b[0;34m(inp, hidden, params, has_biases, reverse)\u001b[0m\n\u001b[1;32m   3331\u001b[0m precomputed_input \u001b[38;5;241m=\u001b[39m precomputed_input\u001b[38;5;241m.\u001b[39mflip(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m reverse \u001b[38;5;28;01melse\u001b[39;00m precomputed_input\n\u001b[1;32m   3332\u001b[0m step_output \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m-> 3333\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m inp \u001b[38;5;129;01min\u001b[39;00m precomputed_input:\n\u001b[1;32m   3334\u001b[0m     hx, cx \u001b[38;5;241m=\u001b[39m lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m   3335\u001b[0m     step_output\u001b[38;5;241m.\u001b[39mappend(hx)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_tensor.py:1119\u001b[0m, in \u001b[0;36mTensor.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1110\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_tracing_state():\n\u001b[1;32m   1111\u001b[0m     warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m   1112\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIterating over a tensor might cause the trace to be incorrect. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1113\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing a tensor of different shape won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt change the number of \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1117\u001b[0m         stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m,\n\u001b[1;32m   1118\u001b[0m     )\n\u001b[0;32m-> 1119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28miter\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munbind\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:534\u001b[0m, in \u001b[0;36mFunctionalTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m    525\u001b[0m     outs_wrapped \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_map_only(\n\u001b[1;32m    526\u001b[0m         torch\u001b[38;5;241m.\u001b[39mTensor, wrap, outs_unwrapped\n\u001b[1;32m    527\u001b[0m     )\n\u001b[1;32m    528\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    529\u001b[0m     \u001b[38;5;66;03m# When we dispatch to the C++ functionalization kernel, we might need to jump back to the\u001b[39;00m\n\u001b[1;32m    530\u001b[0m     \u001b[38;5;66;03m# PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath\u001b[39;00m\n\u001b[1;32m    531\u001b[0m     \u001b[38;5;66;03m# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch\u001b[39;00m\n\u001b[1;32m    532\u001b[0m     \u001b[38;5;66;03m# from the TLS in order to avoid infinite looping, but this would prevent us from coming\u001b[39;00m\n\u001b[1;32m    533\u001b[0m     \u001b[38;5;66;03m# back to PreDispatch later\u001b[39;00m\n\u001b[0;32m--> 534\u001b[0m     outs_unwrapped \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_op_dk\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    535\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDispatchKey\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mFunctionalize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    536\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs_unwrapped\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    537\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_unwrapped\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    538\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    539\u001b[0m     \u001b[38;5;66;03m# We don't allow any mutation on result of dropout or _to_copy\u001b[39;00m\n\u001b[1;32m    540\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexport:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_stats.py:21\u001b[0m, in \u001b[0;36mcount.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     19\u001b[0m     simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m     20\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1238\u001b[0m, in \u001b[0;36mFakeTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1234\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[1;32m   1235\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_dispatch_mode(torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_TorchDispatchModeKey\u001b[38;5;241m.\u001b[39mFAKE) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1236\u001b[0m ), func\n\u001b[1;32m   1237\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1238\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1239\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m   1240\u001b[0m     log\u001b[38;5;241m.\u001b[39mexception(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake tensor raised TypeError\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1692\u001b[0m, in \u001b[0;36mFakeTensorMode.dispatch\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1689\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m   1691\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_enabled:\n\u001b[0;32m-> 1692\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cached_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1693\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1694\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dispatch_impl(func, types, args, kwargs)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1348\u001b[0m, in \u001b[0;36mFakeTensorMode._cached_dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1345\u001b[0m     FakeTensorMode\u001b[38;5;241m.\u001b[39mcache_bypasses[e\u001b[38;5;241m.\u001b[39mreason] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output \u001b[38;5;129;01mis\u001b[39;00m _UNASSIGNED:\n\u001b[0;32m-> 1348\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1350\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1943\u001b[0m, in \u001b[0;36mFakeTensorMode._dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1933\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func \u001b[38;5;129;01min\u001b[39;00m decomposition_table \u001b[38;5;129;01mand\u001b[39;00m (\n\u001b[1;32m   1934\u001b[0m     has_symbolic_sizes\n\u001b[1;32m   1935\u001b[0m     \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1940\u001b[0m     )\n\u001b[1;32m   1941\u001b[0m ):\n\u001b[1;32m   1942\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m-> 1943\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdecomposition_table\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1945\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m   1946\u001b[0m     \u001b[38;5;66;03m# Decomposes CompositeImplicitAutograd ops\u001b[39;00m\n\u001b[1;32m   1947\u001b[0m     r \u001b[38;5;241m=\u001b[39m func\u001b[38;5;241m.\u001b[39mdecompose(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_refs/__init__.py:3956\u001b[0m, in \u001b[0;36munbind\u001b[0;34m(t, dim)\u001b[0m\n\u001b[1;32m   3953\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m ()\n\u001b[1;32m   3954\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   3955\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[0;32m-> 3956\u001b[0m         torch\u001b[38;5;241m.\u001b[39msqueeze(s, dim) \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor_split\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3957\u001b[0m     )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_stats.py:21\u001b[0m, in \u001b[0;36mcount.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     19\u001b[0m     simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m     20\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1238\u001b[0m, in \u001b[0;36mFakeTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1234\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[1;32m   1235\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_dispatch_mode(torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_TorchDispatchModeKey\u001b[38;5;241m.\u001b[39mFAKE) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1236\u001b[0m ), func\n\u001b[1;32m   1237\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1238\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1239\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m   1240\u001b[0m     log\u001b[38;5;241m.\u001b[39mexception(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake tensor raised TypeError\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1692\u001b[0m, in \u001b[0;36mFakeTensorMode.dispatch\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1689\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m   1691\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_enabled:\n\u001b[0;32m-> 1692\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cached_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1693\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1694\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dispatch_impl(func, types, args, kwargs)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1348\u001b[0m, in \u001b[0;36mFakeTensorMode._cached_dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1345\u001b[0m     FakeTensorMode\u001b[38;5;241m.\u001b[39mcache_bypasses[e\u001b[38;5;241m.\u001b[39mreason] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m   1347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output \u001b[38;5;129;01mis\u001b[39;00m _UNASSIGNED:\n\u001b[0;32m-> 1348\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1350\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1947\u001b[0m, in \u001b[0;36mFakeTensorMode._dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m   1943\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m decomposition_table[func](\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m   1945\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m   1946\u001b[0m     \u001b[38;5;66;03m# Decomposes CompositeImplicitAutograd ops\u001b[39;00m\n\u001b[0;32m-> 1947\u001b[0m     r \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecompose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1948\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m r \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m:\n\u001b[1;32m   1949\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m r\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py:759\u001b[0m, in \u001b[0;36mOpOverload.decompose\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    757\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpy_kernels[dk](\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m    758\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_dispatch_has_kernel_for_dispatch_key(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname(), dk):\n\u001b[0;32m--> 759\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_op_dk\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    760\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    761\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py:429\u001b[0m, in \u001b[0;36mSymNode.guard_int\u001b[0;34m(self, file, line)\u001b[0m\n\u001b[1;32m    426\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mguard_int\u001b[39m(\u001b[38;5;28mself\u001b[39m, file, line):\n\u001b[1;32m    427\u001b[0m     \u001b[38;5;66;03m# TODO: use the file/line for some useful diagnostic on why a\u001b[39;00m\n\u001b[1;32m    428\u001b[0m     \u001b[38;5;66;03m# guard occurred\u001b[39;00m\n\u001b[0;32m--> 429\u001b[0m     r \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate_expr\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfx_node\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfx_node\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    430\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    431\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mint\u001b[39m(r)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py:262\u001b[0m, in \u001b[0;36mrecord_shapeenv_event.<locals>.decorator.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    255\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    256\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mis_recording:  \u001b[38;5;66;03m# type: ignore[has-type]\u001b[39;00m\n\u001b[1;32m    257\u001b[0m         \u001b[38;5;66;03m# If ShapeEnv is already recording an event, call the wrapped\u001b[39;00m\n\u001b[1;32m    258\u001b[0m         \u001b[38;5;66;03m# function directly.\u001b[39;00m\n\u001b[1;32m    259\u001b[0m         \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m    260\u001b[0m         \u001b[38;5;66;03m# NB: here, we skip the check of whether all ShapeEnv instances\u001b[39;00m\n\u001b[1;32m    261\u001b[0m         \u001b[38;5;66;03m# are equal, in favor of a faster dispatch.\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m retlog(\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m    264\u001b[0m     \u001b[38;5;66;03m# Retrieve an instance of ShapeEnv.\u001b[39;00m\n\u001b[1;32m    265\u001b[0m     \u001b[38;5;66;03m# Assumption: the collection of args and kwargs may not reference\u001b[39;00m\n\u001b[1;32m    266\u001b[0m     \u001b[38;5;66;03m# different ShapeEnv instances.\u001b[39;00m\n\u001b[1;32m    267\u001b[0m     \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m _extract_shape_env_and_assert_equal(args, kwargs)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5122\u001b[0m, in \u001b[0;36mShapeEnv.evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m   5117\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(\u001b[38;5;241m256\u001b[39m)\n\u001b[1;32m   5118\u001b[0m \u001b[38;5;129m@record_shapeenv_event\u001b[39m(save_tracked_fakes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m   5119\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_expr\u001b[39m(\u001b[38;5;28mself\u001b[39m, orig_expr: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msympy.Expr\u001b[39m\u001b[38;5;124m\"\u001b[39m, hint\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, fx_node\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   5120\u001b[0m                   size_oblivious: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m, forcing_spec: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m   5121\u001b[0m     \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 5122\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluate_expr\u001b[49m\u001b[43m(\u001b[49m\u001b[43morig_expr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfx_node\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize_oblivious\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforcing_spec\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforcing_spec\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   5123\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m   5124\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m   5125\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfailed during evaluate_expr(\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, hint=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, size_oblivious=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, forcing_spec=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m   5126\u001b[0m             orig_expr, hint, size_oblivious, forcing_spec\n\u001b[1;32m   5127\u001b[0m         )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5238\u001b[0m, in \u001b[0;36mShapeEnv._evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m   5236\u001b[0m         concrete_val \u001b[38;5;241m=\u001b[39m unsound_result\n\u001b[1;32m   5237\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 5238\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_data_dependent_error(\n\u001b[1;32m   5239\u001b[0m             expr\u001b[38;5;241m.\u001b[39mxreplace(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvar_to_val),\n\u001b[1;32m   5240\u001b[0m             expr,\n\u001b[1;32m   5241\u001b[0m             size_oblivious_result\u001b[38;5;241m=\u001b[39msize_oblivious_result\n\u001b[1;32m   5242\u001b[0m         )\n\u001b[1;32m   5243\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   5244\u001b[0m     expr \u001b[38;5;241m=\u001b[39m new_expr\n",
      "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0).  (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n  File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n    return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n  File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n    x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n  1. torch._check(x.shape[2])\n  2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)"
     ]
    }
   ],
   "source": [
    "os.environ['TORCH_LOGS'] = '+dynamic'\n",
    "os.environ['TORCH_LOGS'] = '+export'\n",
    "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"u0 >= 0\"\n",
    "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP']=\"1\"\n",
    "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL']=\"u0\"\n",
    "\n",
    "class StyleTTS2(torch.nn.Module):\n",
    "    def __init__(self, model, voicepack):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.voicepack = voicepack\n",
    "    \n",
    "    def forward(self, tokens):\n",
    "        speed = 1.\n",
    "        # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n",
    "        device = tokens.device\n",
    "        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
    "\n",
    "        text_mask = length_to_mask(input_lengths).to(device)\n",
    "        bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())\n",
    "\n",
    "        d_en = self.model[\"bert_encoder\"](bert_dur).transpose(-1, -2)\n",
    "\n",
    "        ref_s = self.voicepack[tokens.shape[1]]\n",
    "        s = ref_s[:, 128:]\n",
    "\n",
    "        d = self.model[\"predictor\"].text_encoder.inference(d_en, s)\n",
    "        x, _ = self.model[\"predictor\"].lstm(d)\n",
    "\n",
    "        duration = self.model[\"predictor\"].duration_proj(x)\n",
    "        duration = torch.sigmoid(duration).sum(axis=-1) / speed\n",
    "        pred_dur = torch.round(duration).clamp(min=1).long()\n",
    "        \n",
    "        c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n",
    "        c_end = c_start + pred_dur[0,:]\n",
    "        indices = torch.arange(0, pred_dur.sum().item()).long().to(device)\n",
    "\n",
    "        pred_aln_trg_list=[]\n",
    "        for cs, ce in zip(c_start, c_end):\n",
    "            row = torch.where((indices>=cs) & (indices<ce), 1., 0.)\n",
    "            pred_aln_trg_list.append(row)\n",
    "        pred_aln_trg=torch.vstack(pred_aln_trg_list)\n",
    "            \n",
    "        en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)\n",
    "        \n",
    "        F0_pred, N_pred = self.model[\"predictor\"].F0Ntrain(en, s)\n",
    "        t_en = self.model[\"text_encoder\"].inference(tokens)\n",
    "        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)\n",
    "        return (asr, F0_pred, N_pred, ref_s[:, :128])\n",
    "        # output = self.model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().detach().cpu().numpy()\n",
    "\n",
    "\n",
    "style_model = StyleTTS2(model=model, voicepack=voicepack)\n",
    "(asr, F0_pred, N_pred, ref_s) = style_model(tokens)\n",
    "\n",
    "token_len = torch.export.Dim(\"token_len\", min=2, max=510)\n",
    "batch = torch.export.Dim(\"batch\")\n",
    "dynamic_shapes = {\"tokens\":{0:batch, 1:token_len}}\n",
    "\n",
    "# with torch.no_grad():\n",
    "export_mod = torch.export.export(style_model, args=( tokens, ), dynamic_shapes=dynamic_shapes, strict=False)\n",
    "# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0101 18:19:15.402000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0]._base.size()[1] [2, int_oo] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n",
      "I0101 18:19:15.407000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s1 = 143 for L['args'][0][0].size()[0] [2, int_oo] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s1\"\n",
      "I0101 18:19:15.420000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s1 = 143 (range_refined_to_singleton) VR[143, 143]\n",
      "I0101 18:19:15.422000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s1, 143) [guard added] (mp/ipykernel_2488298/2011460168.py:16 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s1, 143)\"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 143])\n",
      "torch.Size([1, s1])\n",
      "torch.Size([1, 143])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "I0101 18:19:33.124000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3646] produce_guards\n"
     ]
    },
    {
     "ename": "UserError",
     "evalue": "Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n  - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n  token_len = 143",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mConstraintViolationError\u001b[0m                  Traceback (most recent call last)",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:670\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m    669\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 670\u001b[0m     \u001b[43mproduce_guards_callback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1655\u001b[0m, in \u001b[0;36m_non_strict_export.<locals>._produce_guards_callback\u001b[0;34m(gm)\u001b[0m\n\u001b[1;32m   1654\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_produce_guards_callback\u001b[39m(gm):\n\u001b[0;32m-> 1655\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mproduce_guards_and_solve_constraints\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1656\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1657\u001b[0m \u001b[43m        \u001b[49m\u001b[43mgm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1658\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransformed_dynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1659\u001b[0m \u001b[43m        \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1660\u001b[0m \u001b[43m        \u001b[49m\u001b[43moriginal_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moriginal_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1661\u001b[0m \u001b[43m        \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1662\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:305\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m    304\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m constraint_violation_error:\n\u001b[0;32m--> 305\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m constraint_violation_error\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:270\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m    269\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 270\u001b[0m     \u001b[43mshape_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproduce_guards\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m        \u001b[49m\u001b[43mplaceholders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m        \u001b[49m\u001b[43msources\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m        \u001b[49m\u001b[43minput_contexts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_contexts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m        \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    275\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    276\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    277\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConstraintViolationError \u001b[38;5;28;01mas\u001b[39;00m e:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:4178\u001b[0m, in \u001b[0;36mShapeEnv.produce_guards\u001b[0;34m(self, placeholders, sources, source_ref, guards, input_contexts, equalities_inputs, _simplified, ignore_static)\u001b[0m\n\u001b[1;32m   4177\u001b[0m     err \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)\n\u001b[0;32m-> 4178\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m ConstraintViolationError(\n\u001b[1;32m   4179\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConstraints violated (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdebug_names\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m)! \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   4180\u001b[0m         \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFor more information, run with TORCH_LOGS=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m+dynamic\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m   4181\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   4182\u001b[0m     )\n\u001b[1;32m   4183\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(warn_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "\u001b[0;31mConstraintViolationError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n  - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n  token_len = 143",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mUserError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[33], line 61\u001b[0m\n\u001b[1;32m     58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens0\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:token_len}}\n\u001b[1;32m     60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m     62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(test_model, args=( tokens[0,:], ), strict=False).run_decompositions()\u001b[39;00m\n\u001b[1;32m     63\u001b[0m \u001b[38;5;28mprint\u001b[39m(export_mod)\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m    264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m    265\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    266\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    267\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    268\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    269\u001b[0m     )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m    \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    275\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    276\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    277\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1010\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1011\u001b[0m         log_export_usage(\n\u001b[1;32m   1012\u001b[0m             event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m   1013\u001b[0m             \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m   1014\u001b[0m             message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m   1015\u001b[0m             flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m   1016\u001b[0m         )\n\u001b[0;32m-> 1017\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m   1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m   1019\u001b[0m     _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    989\u001b[0m     start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m     ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    991\u001b[0m     end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m    992\u001b[0m     log_export_usage(\n\u001b[1;32m    993\u001b[0m         event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m    994\u001b[0m         metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m    995\u001b[0m         flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m    996\u001b[0m         \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m    997\u001b[0m     )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/exported_program.py:114\u001b[0m, in \u001b[0;36m_disable_prexisiting_fake_mode.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    111\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m    112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    113\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m unset_fake_temporarily():\n\u001b[0;32m--> 114\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1880\u001b[0m, in \u001b[0;36m_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m   1877\u001b[0m \u001b[38;5;66;03m# Call the appropriate export function based on the strictness of tracing.\u001b[39;00m\n\u001b[1;32m   1878\u001b[0m export_func \u001b[38;5;241m=\u001b[39m _strict_export \u001b[38;5;28;01mif\u001b[39;00m strict \u001b[38;5;28;01melse\u001b[39;00m _non_strict_export\n\u001b[0;32m-> 1880\u001b[0m export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43mexport_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m   1881\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1882\u001b[0m \u001b[43m    \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1883\u001b[0m \u001b[43m    \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1884\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1885\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1886\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1887\u001b[0m \u001b[43m    \u001b[49m\u001b[43moriginal_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1888\u001b[0m \u001b[43m    \u001b[49m\u001b[43moriginal_in_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1889\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_complex_guards_as_runtime_asserts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1890\u001b[0m \u001b[43m    \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1891\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1892\u001b[0m export_graph_signature: ExportGraphSignature \u001b[38;5;241m=\u001b[39m export_artifact\u001b[38;5;241m.\u001b[39maten\u001b[38;5;241m.\u001b[39msig\n\u001b[1;32m   1894\u001b[0m forward_arg_names \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   1895\u001b[0m     _get_forward_arg_names(mod, args, kwargs) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_torch_jit_trace \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1896\u001b[0m )\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1683\u001b[0m, in \u001b[0;36m_non_strict_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, dispatch_tracing_mode)\u001b[0m\n\u001b[1;32m   1667\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) \u001b[38;5;28;01mas\u001b[39;00m (\n\u001b[1;32m   1668\u001b[0m     patched_mod,\n\u001b[1;32m   1669\u001b[0m     new_fake_args,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1672\u001b[0m     map_fake_to_real,\n\u001b[1;32m   1673\u001b[0m ):\n\u001b[1;32m   1674\u001b[0m     _to_aten_func \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   1675\u001b[0m         _export_to_aten_ir_make_fx\n\u001b[1;32m   1676\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m dispatch_tracing_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmake_fx\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1681\u001b[0m         )\n\u001b[1;32m   1682\u001b[0m     )\n\u001b[0;32m-> 1683\u001b[0m     aten_export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43m_to_aten_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m   1684\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpatched_mod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1685\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnew_fake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1686\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnew_fake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1687\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfake_params_buffers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1688\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnew_fake_constant_attrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1689\u001b[0m \u001b[43m        \u001b[49m\u001b[43mproduce_guards_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_produce_guards_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1690\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_tuplify_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1691\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1692\u001b[0m     \u001b[38;5;66;03m# aten_export_artifact.constants contains only fake script objects, we need to map them back\u001b[39;00m\n\u001b[1;32m   1693\u001b[0m     aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m   1694\u001b[0m         fqn: map_fake_to_real[obj] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, FakeScriptObject) \u001b[38;5;28;01melse\u001b[39;00m obj\n\u001b[1;32m   1695\u001b[0m         \u001b[38;5;28;01mfor\u001b[39;00m fqn, obj \u001b[38;5;129;01min\u001b[39;00m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m   1696\u001b[0m     }\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:672\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m    670\u001b[0m         produce_guards_callback(gm)\n\u001b[1;32m    671\u001b[0m     \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 672\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m UserError(UserErrorType\u001b[38;5;241m.\u001b[39mCONSTRAINT_VIOLATION, \u001b[38;5;28mstr\u001b[39m(e))  \u001b[38;5;66;03m# noqa: B904\u001b[39;00m\n\u001b[1;32m    674\u001b[0m \u001b[38;5;66;03m# Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.\u001b[39;00m\n\u001b[1;32m    675\u001b[0m \u001b[38;5;66;03m# Overwrite output specs afterwards.\u001b[39;00m\n\u001b[1;32m    676\u001b[0m flat_fake_args \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_leaves((fake_args, fake_kwargs))\n",
      "\u001b[0;31mUserError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n  - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n  token_len = 143"
     ]
    }
   ],
   "source": [
    "os.environ['TORCH_LOGS'] = '+dynamic'\n",
    "os.environ['TORCH_LOGS'] = '+export'\n",
    "class test(torch.nn.Module):\n",
    "    def __init__(self, model, voicepack):\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        self.voicepack = voicepack\n",
    "        self.model.text_encoder.lstm.flatten_parameters()\n",
    "    \n",
    "    def forward(self, tokens0):\n",
    "        tokens = tokens0.unsqueeze(0)\n",
    "        print(tokens.shape)\n",
    "        # speed = 1.\n",
    "        # # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n",
    "        # device = tokens.device\n",
    "        input_lengths = torch.LongTensor([tokens0.shape[-1]]).to(device)\n",
    "\n",
    "        # text_mask = length_to_mask(input_lengths).to(device)\n",
    "        # bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())\n",
    "\n",
    "        # d_en = self.model[\"bert_encoder\"](bert_dur).transpose(-1, -2)\n",
    "\n",
    "        # ref_s = self.voicepack[tokens.shape[1]]\n",
    "        # s = ref_s[:, 128:]\n",
    "\n",
    "        # d = self.model[\"predictor\"].text_encoder.inference(d_en, s)\n",
    "        # x, _ = self.model[\"predictor\"].lstm(d)\n",
    "\n",
    "        # duration = self.model[\"predictor\"].duration_proj(x)\n",
    "        # duration = torch.sigmoid(duration).sum(axis=-1) / speed\n",
    "        # pred_dur = torch.round(duration).clamp(min=1).long()\n",
    "        \n",
    "        # c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n",
    "        # c_end = c_start + pred_dur[0,:]\n",
    "        # indices = torch.arange(0, pred_dur.sum().item()).long().to(device)\n",
    "\n",
    "        # pred_aln_trg_list=[]\n",
    "        # for cs, ce in zip(c_start, c_end):\n",
    "        #     row = torch.where((indices>=cs) & (indices<ce), 1., 0.)\n",
    "        #     pred_aln_trg_list.append(row)\n",
    "        # pred_aln_trg=torch.vstack(pred_aln_trg_list)\n",
    "            \n",
    "        # en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)\n",
    "\n",
    "        en = torch.rand((1, 640, 2*tokens.shape[-1]))\n",
    "        s = torch.rand((1,128))\n",
    "        F0_pred, N_pred = self.model[\"predictor\"].F0Ntrain(en, s)\n",
    "        t_en = self.model[\"text_encoder\"].inference(tokens)\n",
    "        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)\n",
    "        return (asr, F0_pred, N_pred, ref_s[:, :128])\n",
    "        # output = self.model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().detach().cpu().numpy()\n",
    "\n",
    "\n",
    "test_model = test(model=model, voicepack=voicepack)\n",
    "(asr, F0_pred, N_pred, ref_s) = test_model(tokens[0,:])\n",
    "\n",
    "token_len = torch.export.Dim(\"token_len\") #, min=2, max=510)\n",
    "dynamic_shapes = {\"tokens0\":{0:token_len}}\n",
    "\n",
    "# with torch.no_grad():\n",
    "export_mod = torch.export.export(test_model, args=( tokens[0,:], ), dynamic_shapes=dynamic_shapes, strict=False)\n",
    "# export_mod = torch.export.export(test_model, args=( tokens[0,:], ), strict=False).run_decompositions()\n",
    "print(export_mod)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.float32\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())\n",
    "c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n",
    "c_end = c_start + pred_dur[0,:]\n",
    "indices = torch.arange(0, pred_dur.sum().item()).to(device)\n",
    "\n",
    "pred_aln_trg_list=[]\n",
    "for cs, ce in zip(c_start, c_end):\n",
    "    row = torch.where((indices>=cs) & (indices<ce), 1., 0.)\n",
    "    pred_aln_trg_list.append(row)\n",
    "\n",
    "pred_aln_trg=torch.vstack(pred_aln_trg_list)\n",
    "# print(pred_aln_trg)\n",
    "# pl.imshow(pred_aln_trg)\n",
    "pl.plot(pred_aln_trg[:,50])\n",
    "# print(pred_dur.shape)\n",
    "print(row.dtype)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 142])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "c_start[:,:-2].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (23) must match the size of tensor b (143) at non-singleton dimension 2",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[48], line 41\u001b[0m\n\u001b[1;32m     36\u001b[0m     pred_aln_trg[flat_indices\u001b[38;5;241m.\u001b[39mT[\u001b[38;5;241m0\u001b[39m], flat_indices\u001b[38;5;241m.\u001b[39mT[\u001b[38;5;241m1\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     38\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m pred_aln_trg\n\u001b[0;32m---> 41\u001b[0m pred_aln_trg \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_alignment_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_lengths\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpred_dur\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     42\u001b[0m pl\u001b[38;5;241m.\u001b[39mimshow(pred_aln_trg)\n",
      "Cell \u001b[0;32mIn[48], line 22\u001b[0m, in \u001b[0;36mcreate_alignment_matrix\u001b[0;34m(input_lengths, pred_dur)\u001b[0m\n\u001b[1;32m     19\u001b[0m col_indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39marange(pred_dur\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m.\u001b[39mitem())\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mrepeat(input_lengths, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     21\u001b[0m \u001b[38;5;66;03m# Create a mask based on durations\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m mask \u001b[38;5;241m=\u001b[39m \u001b[43mcol_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m<\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred_dur\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     24\u001b[0m \u001b[38;5;66;03m# Create offset indices for the columns\u001b[39;00m\n\u001b[1;32m     25\u001b[0m offset \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat((torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0\u001b[39m]), cum_dur[\u001b[38;5;241m0\u001b[39m, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]))\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mrepeat(\u001b[38;5;241m1\u001b[39m, pred_dur\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m.\u001b[39mitem())\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (143) at non-singleton dimension 2"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "def create_alignment_matrix(input_lengths, pred_dur):\n",
    "    \"\"\"Creates an alignment matrix without explicit loops.\n",
    "\n",
    "    Args:\n",
    "        input_lengths: Number of input units (int).\n",
    "        pred_dur: Predicted durations (torch.Tensor of shape (1, input_lengths)).\n",
    "\n",
    "    Returns:\n",
    "        pred_aln_trg: Alignment matrix (torch.Tensor of shape (input_lengths, pred_dur.sum())).\n",
    "    \"\"\"\n",
    "    total_duration = pred_dur.sum().item()\n",
    "    pred_aln_trg = torch.zeros(input_lengths, total_duration)\n",
    "\n",
    "    # Calculate cumulative durations\n",
    "    cum_dur = torch.cumsum(pred_dur, dim=1)\n",
    "\n",
    "    # Create indices for filling the matrix\n",
    "    row_indices = torch.arange(input_lengths).unsqueeze(1).repeat(1, pred_dur.max().item())\n",
    "    col_indices = torch.arange(pred_dur.max().item()).unsqueeze(0).repeat(input_lengths, 1)\n",
    "\n",
    "    # Create a mask based on durations\n",
    "    mask = col_indices < pred_dur.unsqueeze(1)\n",
    "\n",
    "    # Create offset indices for the columns\n",
    "    offset = torch.cat((torch.tensor([0]), cum_dur[0, :-1])).unsqueeze(1).repeat(1, pred_dur.max().item())\n",
    "\n",
    "    # Apply the mask and offset to generate the final column indices\n",
    "    final_col_indices = (col_indices + offset) * mask\n",
    "\n",
    "    # Flatten indices and create a flattened index tensor\n",
    "    flat_row_indices = row_indices[mask].long()\n",
    "    flat_col_indices = final_col_indices[mask].long()\n",
    "    flat_indices = torch.stack([flat_row_indices, flat_col_indices], dim=1)\n",
    "\n",
    "    # Scatter ones into the alignment matrix\n",
    "    pred_aln_trg[flat_indices.T[0], flat_indices.T[1]] = 1\n",
    "\n",
    "    return pred_aln_trg\n",
    "\n",
    "\n",
    "pred_aln_trg = create_alignment_matrix(input_lengths.item(), pred_dur)\n",
    "pl.imshow(pred_aln_trg)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([143])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())\n",
    "c_frame = 0\n",
    "\n",
    "for i in range(pred_aln_trg.size(0)):\n",
    "    pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1\n",
    "    c_frame += pred_dur[0,i].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method Module.eval of StyleTTS2()>"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "style_model.eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CustomAlbert(\n",
       "  (embeddings): AlbertEmbeddings(\n",
       "    (word_embeddings): Embedding(178, 128, padding_idx=0)\n",
       "    (position_embeddings): Embedding(512, 128)\n",
       "    (token_type_embeddings): Embedding(2, 128)\n",
       "    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n",
       "    (dropout): Dropout(p=0, inplace=False)\n",
       "  )\n",
       "  (encoder): AlbertTransformer(\n",
       "    (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)\n",
       "    (albert_layer_groups): ModuleList(\n",
       "      (0): AlbertLayerGroup(\n",
       "        (albert_layers): ModuleList(\n",
       "          (0): AlbertLayer(\n",
       "            (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (attention): AlbertAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (attention_dropout): Dropout(p=0, inplace=False)\n",
       "              (output_dropout): Dropout(p=0, inplace=False)\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            )\n",
       "            (ffn): Linear(in_features=768, out_features=2048, bias=True)\n",
       "            (ffn_output): Linear(in_features=2048, out_features=768, bias=True)\n",
       "            (activation): NewGELUActivation()\n",
       "            (dropout): Dropout(p=0, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (pooler): Linear(in_features=768, out_features=768, bias=True)\n",
       "  (pooler_activation): Tanh()\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model['bert']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "only integer tensors of a single element can be converted to an index",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[17], line 11\u001b[0m\n\u001b[1;32m      8\u001b[0m pred_aln_trg1 \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros(input_lengths, pred_dur\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mitem(), dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m      9\u001b[0m batch_indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39marange(input_lengths\u001b[38;5;241m.\u001b[39mitem())\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 11\u001b[0m \u001b[43mpred_aln_trg1\u001b[49m\u001b[43m[\u001b[49m\u001b[43mbatch_indices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_indices\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mend_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m     13\u001b[0m pl\u001b[38;5;241m.\u001b[39mimshow(pred_aln_trg1)\n",
      "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index"
     ]
    }
   ],
   "source": [
    "# Process durations\n",
    "\n",
    "cumsum_dur = torch.cumsum(pred_dur, dim=1).to(device)\n",
    "end_indices = cumsum_dur - 1\n",
    "start_indices = torch.cat([torch.zeros(1, 1, dtype=torch.long).to(device), end_indices[:, :-1] + 1], dim=1)\n",
    "\n",
    "# Create binary alignment target\n",
    "pred_aln_trg1 = torch.zeros(input_lengths, pred_dur.sum().item(), dtype=torch.float32)\n",
    "batch_indices = torch.arange(input_lengths.item()).unsqueeze(1)\n",
    "\n",
    "pred_aln_trg1[batch_indices, start_indices: end_indices + 1] = 1\n",
    "\n",
    "pl.imshow(pred_aln_trg1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([143, 329])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_aln_trg1 = torch.zeros(input_lengths, pred_dur.sum().item()).to(device)\n",
    "a = torch.arange(pred_aln_trg1.size(0))[:, None].repeat(1, pred_dur.size(1)).to(device)\n",
    "b = (torch.arange(pred_dur.size(1)).repeat(pred_aln_trg1.size(0), 1).to(device) < pred_dur).to(torch.float32).to(device)\n",
    "print(pred_aln_trg.dtype, pred_aln_trg1.dtype, a.dtype, b.dtype)\n",
    "print(a.device, b.device, pred_dur.device)\n",
    "pred_aln_trg1.scatter_(1, \n",
    "                      a, \n",
    "                      b)\n",
    "\n",
    "pl.imshow(pred_aln_trg1.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Expected index [1, 25] to be smaller than self [143, 329] apart from dimension 1 and to be smaller size than src [1, 1]",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[8], line 18\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;66;03m# Use scatter_add_ to set the appropriate slices to 1\u001b[39;00m\n\u001b[1;32m     17\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(pred_dur\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m)):\n\u001b[0;32m---> 18\u001b[0m \t\u001b[43mmask\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscatter_add_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart_indices\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m:\u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred_dur\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclamp\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mmax\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpred_aln_trg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m:\u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     20\u001b[0m \u001b[38;5;66;03m# Apply the mask to pred_aln_trg\u001b[39;00m\n\u001b[1;32m     21\u001b[0m pred_aln_trg \u001b[38;5;241m=\u001b[39m mask\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Expected index [1, 25] to be smaller than self [143, 329] apart from dimension 1 and to be smaller size than src [1, 1]"
     ]
    }
   ],
   "source": [
    "# Calculate the cumulative sum of durations to get the end indices\n",
    "cumulative_durations = torch.cumsum(pred_dur, dim=1).to(device)\n",
    "\n",
    "# Calculate the start indices by shifting the cumulative durations\n",
    "start_indices = cumulative_durations - pred_dur\n",
    "\n",
    "# Create a tensor of indices for pred_aln_trg\n",
    "indices = torch.arange(pred_aln_trg.size(1)).to(device)\n",
    "\n",
    "# Create a mask tensor initialized to zeros\n",
    "mask = torch.zeros_like(pred_aln_trg).to(device)\n",
    "\n",
    "# Create a tensor to hold the values to scatter\n",
    "values = torch.ones_like(pred_dur, dtype=pred_aln_trg.dtype).to(device)\n",
    "\n",
    "# Use scatter_ to set the appropriate slices to 1\n",
    "mask.scatter_(1, start_indices.unsqueeze(2) + torch.arange(pred_dur.max()).unsqueeze(0).unsqueeze(0).to(device), values.unsqueeze(2))\n",
    "\n",
    "# Apply the mask to pred_aln_trg\n",
    "pred_aln_trg = mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.arange(pred_dur.size(1)).repeat(pred_aln_trg1.size(0), 1).device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f715c3faf50>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pl.imshow(pred_aln_trg.detach().cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 143])\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(pred_dur.shape)\n",
    "pl.plot(pred_dur[0,:].detach().cpu().numpy().cumsum());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[50, 157, 43, 135, 16, 53, 135, 46, 16, 43, 102, 16, 56, 156, 57, 135, 6, 16, 102, 62, 61, 16, 70, 56, 16, 138, 56, 156, 72, 56, 61, 85, 123, 83, 44, 83, 54, 16, 53, 65, 156, 86, 61, 62, 131, 83, 56, 4, 16, 54, 156, 43, 102, 53, 16, 156, 72, 61, 53, 102, 112, 16, 70, 56, 16, 138, 56, 44, 156, 76, 158, 123, 56, 16, 62, 131, 156, 43, 102, 54, 46, 16, 102, 48, 16, 81, 47, 102, 54, 16, 54, 156, 51, 158, 46, 16, 70, 16, 92, 156, 135, 46, 16, 54, 156, 43, 102, 48, 4, 16, 81, 47, 102, 16, 50, 156, 72, 64, 83, 56, 62, 16, 156, 51, 158, 64, 83, 56, 16, 44, 157, 102, 56, 16, 44, 156, 76, 158, 123, 56, 4]\n"
     ]
    }
   ],
   "source": [
    "ps = phonemize(text)\n",
    "tokens = tokenize(ps)\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models import build_model\n",
    "import torch\n",
    "device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = build_model('kokoro-v0_19.pth', device)\n",
    "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert = model[\"bert\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embeddings.word_embeddings.weight torch.Size([178, 128])\n",
      "embeddings.position_embeddings.weight torch.Size([512, 128])\n",
      "embeddings.token_type_embeddings.weight torch.Size([2, 128])\n",
      "embeddings.LayerNorm.weight torch.Size([128])\n",
      "embeddings.LayerNorm.bias torch.Size([128])\n",
      "encoder.embedding_hidden_mapping_in.weight torch.Size([768, 128])\n",
      "encoder.embedding_hidden_mapping_in.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.weight torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.query.weight torch.Size([768, 768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.query.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.key.weight torch.Size([768, 768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.key.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.value.weight torch.Size([768, 768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.value.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.weight torch.Size([768, 768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.weight torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.bias torch.Size([768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.ffn.weight torch.Size([2048, 768])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.ffn.bias torch.Size([2048])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight torch.Size([768, 2048])\n",
      "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias torch.Size([768])\n",
      "pooler.weight torch.Size([768, 768])\n",
      "pooler.bias torch.Size([768])\n"
     ]
    }
   ],
   "source": [
    "# show all parameters of model bert\n",
    "for name, param in bert.named_parameters():\n",
    "    print(name, param.requires_grad())\n",
    "    # print(param)\n",
    "    # print(param.shape)\n",
    "    # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Testing LSTM export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x1.shape=torch.Size([1, 300, 256])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:4279: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. \n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exported graph: graph(%x : Float(*, 300, 128, strides=[38400, 128, 1], requires_grad=0, device=cpu),\n",
      "      %onnx::LSTM_194 : Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cpu),\n",
      "      %onnx::LSTM_195 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu),\n",
      "      %onnx::LSTM_196 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu)):\n",
      "  %/lstm/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/lstm/Shape\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n",
      "  %/lstm/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name=\"/lstm/Constant\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n",
      "  %/lstm/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/lstm/Gather\"](%/lstm/Shape_output_0, %/lstm/Constant_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n",
      "  %/lstm/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={2}, onnx_name=\"/lstm/Constant_1\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n",
      "  %onnx::Unsqueeze_16 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n",
      "  %/lstm/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/lstm/Unsqueeze\"](%/lstm/Gather_output_0, %onnx::Unsqueeze_16), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n",
      "  %/lstm/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={128}, onnx_name=\"/lstm/Constant_2\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n",
      "  %/lstm/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/lstm/Concat\"](%/lstm/Constant_1_output_0, %/lstm/Unsqueeze_output_0, %/lstm/Constant_2_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n",
      "  %/lstm/ConstantOfShape_output_0 : Float(*, *, *, strides=[128, 128, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name=\"/lstm/ConstantOfShape\"](%/lstm/Concat_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n",
      "  %/lstm/Transpose_output_0 : Float(300, *, 128, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  %onnx::LSTM_23 : Tensor? = prim::Constant(), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  %/lstm/LSTM_output_0 : Float(300, 2, *, 128, device=cpu), %/lstm/LSTM_output_1 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu), %/lstm/LSTM_output_2 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu) = onnx::LSTM[direction=\"bidirectional\", hidden_size=128, onnx_name=\"/lstm/LSTM\"](%/lstm/Transpose_output_0, %onnx::LSTM_195, %onnx::LSTM_196, %onnx::LSTM_194, %onnx::LSTM_23, %/lstm/ConstantOfShape_output_0, %/lstm/ConstantOfShape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  %/lstm/Transpose_1_output_0 : Float(300, *, 2, 128, device=cpu) = onnx::Transpose[perm=[0, 2, 1, 3], onnx_name=\"/lstm/Transpose_1\"](%/lstm/LSTM_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  %/lstm/Constant_3_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value= 0  0 -1 [ CPULongType{3} ], onnx_name=\"/lstm/Constant_3\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  %/lstm/Reshape_output_0 : Float(300, *, 256, device=cpu) = onnx::Reshape[allowzero=0, onnx_name=\"/lstm/Reshape\"](%/lstm/Transpose_1_output_0, %/lstm/Constant_3_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  %151 : Float(*, 300, 256, strides=[256, 256, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose_2\"](%/lstm/Reshape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n",
      "  return (%151)\n",
      "\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'NoneType' object has no attribute 'graph'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 37\u001b[0m\n\u001b[1;32m     34\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39monnx\u001b[38;5;241m.\u001b[39mexport(model\u001b[38;5;241m=\u001b[39mmodel, args\u001b[38;5;241m=\u001b[39m( xa, ), dynamic_axes\u001b[38;5;241m=\u001b[39mdynamic_shapes, input_names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m], f\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, dynamo\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m     35\u001b[0m \u001b[38;5;66;03m# export_mod.save(\"model.onnx\")\u001b[39;00m\n\u001b[1;32m     36\u001b[0m \u001b[38;5;66;03m# export_mod.save_diagnostics(\"model_diagnostics.sarif\")\u001b[39;00m\n\u001b[0;32m---> 37\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mexport_mod\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgraph\u001b[49m)\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'graph'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "# os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"Eq(s0, 384)\"\n",
    "\n",
    "# model class containing a single bidirectional LSTM layer\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.lstm = torch.nn.LSTM(128, 128, 1, bidirectional=True, batch_first=True)\n",
    "        #initialize lstm weights\n",
    "        for name, param in self.lstm.named_parameters():\n",
    "            if 'weight' in name:\n",
    "                torch.nn.init.orthogonal_(param)\n",
    "            elif 'bias' in name:\n",
    "                torch.nn.init.zeros_(param)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x1 = x.transpose(-1,-2)\n",
    "        # print(f\"{x.shape=} {x1.shape=}\")\n",
    "        x2, _ = self.lstm(x)\n",
    "        return x2\n",
    "\n",
    "model = Model()\n",
    "model = model.to(\"cpu\")\n",
    "model.eval()\n",
    "\n",
    "#inital input to LSTM in variable x\n",
    "xa = torch.zeros((1, 300, 128)).to(\"cpu\")\n",
    "x1 = model(xa)\n",
    "print(f\"{x1.shape=}\")\n",
    "ntokens = torch.export.Dim(\"ntokens\", min=3)\n",
    "dynamic_shapes= {\"x\":{0:\"ntokens\"}}\n",
    "\n",
    "# scripted = torch.jit.script(model)\n",
    "torch.onnx.export(model=model, args=( xa, ), dynamic_axes=dynamic_shapes, input_names=[\"x\"], f=\"model.onnx\", verbose=True, dynamo=False)\n",
    "# export_mod.save(\"model.onnx\")\n",
    "# export_mod.save_diagnostics(\"model_diagnostics.sarif\")\n",
    "# print(export_mod.graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 143])\n"
     ]
    }
   ],
   "source": [
    "from kokoro import phonemize, tokenize\n",
    "from models_scripting import load_plbert\n",
    "bert = load_plbert()\n",
    "\n",
    "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n",
    "ps = phonemize(text, \"a\")\n",
    "tokens = tokenize(ps)\n",
    "tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)\n",
    "dynamic_shapes = {\"tokens\":{1:'ntokens'}}\n",
    "print(tokens.shape)\n",
    "torch.onnx.export(model=bert, args=( tokens, ), dynamic_axes=dynamic_shapes, input_names=[\"tokens\"], f=\"bert.onnx\", verbose=False, dynamo=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "ename": "Fail",
     "evalue": "[ONNXRuntimeError] : 1 : FAIL : Load model from style_model.onnx failed:Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFail\u001b[0m                                      Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 6\u001b[0m\n\u001b[1;32m      3\u001b[0m onnx_model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstyle_model.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01monnxruntime\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mort\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m ort_session \u001b[38;5;241m=\u001b[39m \u001b[43mort\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstyle_model.onnx\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m      7\u001b[0m outputs \u001b[38;5;241m=\u001b[39m ort_session\u001b[38;5;241m.\u001b[39mrun(\u001b[38;5;28;01mNone\u001b[39;00m, {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m: tokens\u001b[38;5;241m.\u001b[39mnumpy()})\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:465\u001b[0m, in \u001b[0;36mInferenceSession.__init__\u001b[0;34m(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)\u001b[0m\n\u001b[1;32m    462\u001b[0m disabled_optimizers \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdisabled_optimizers\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    464\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 465\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_create_inference_session\u001b[49m\u001b[43m(\u001b[49m\u001b[43mproviders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprovider_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdisabled_optimizers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    466\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mRuntimeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    467\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_enable_fallback:\n",
      "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:526\u001b[0m, in \u001b[0;36mInferenceSession._create_inference_session\u001b[0;34m(self, providers, provider_options, disabled_optimizers)\u001b[0m\n\u001b[1;32m    523\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_register_ep_custom_ops(session_options, providers, provider_options, available_providers)\n\u001b[1;32m    525\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_path:\n\u001b[0;32m--> 526\u001b[0m     sess \u001b[38;5;241m=\u001b[39m \u001b[43mC\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferenceSession\u001b[49m\u001b[43m(\u001b[49m\u001b[43msession_options\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_model_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_read_config_from_model\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    527\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    528\u001b[0m     sess \u001b[38;5;241m=\u001b[39m C\u001b[38;5;241m.\u001b[39mInferenceSession(session_options, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model_bytes, \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_read_config_from_model)\n",
      "\u001b[0;31mFail\u001b[0m: [ONNXRuntimeError] : 1 : FAIL : Load model from style_model.onnx failed:Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "\n",
    "onnx_model = onnx.load(\"style_model.onnx\")\n",
    "import onnxruntime as ort\n",
    "\n",
    "ort_session = ort.InferenceSession(\"style_model.onnx\")\n",
    "outputs = ort_session.run(None, {\"tokens\": tokens.numpy()})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "styletts2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}