{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Kokoro-82M: inference walkthrough and export experiments\n",
    "\n",
    "1. Load the model and a voicepack, and synthesize a sentence with `kokoro.generate`.\n",
    "2. Re-run the same forward pass step by step so intermediate tensors can be inspected.\n",
    "3. Try to script / export the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "\n",
    "# TORCH_LOGS is read when torch is imported, so it must be set before `import torch` below.\n",
    "# '+dynamic' enables verbose logging of dynamic-shape reasoning (symbol creation, guards),\n",
    "# which shows up in the export cell at the end of the notebook.\n",
    "os.environ['TORCH_LOGS'] = '+dynamic'\n",
    "\n",
    "import pylab as pl\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "# Project modules (re-imported on change thanks to %autoreload 2)\n",
    "from kokoro import generate, length_to_mask, phonemize, tokenize\n",
    "from models import build_model\n",
    "\n",
    "SAMPLE_RATE = 24000  # Kokoro returns 24 kHz waveforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthesize with `kokoro.generate`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cpu\"  # 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = build_model('kokoro-v0_19.pth', device)\n",
    "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)\n",
    "\n",
    "# generate returns a 24 kHz audio waveform and the string of output phonemes\n",
    "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n",
    "audio, out_ps = generate(model, text, voicepack)\n",
    "\n",
    "display(Audio(data=audio, rate=SAMPLE_RATE, autoplay=True))\n",
    "print(out_ps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step-by-step forward pass\n",
    "\n",
    "Mirrors `kokoro.forward`, unrolled so the intermediate tensors (`d_en`, `d`, `en`, `t_en`, `asr`, `pred_aln_trg`, ...) remain available for the inspection cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       " \n",
       " "
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hˌaʊ kʊd aɪ nˈoʊ? ɪts ɐn ʌnˈænsɚɹəbəl kwˈɛstʃən. lˈaɪk ˈæskɪŋ ɐn ʌnbˈɔːɹn tʃˈaɪld ɪf ðeɪl lˈiːd ɐ ɡˈʊd lˈaɪf. ðeɪ hˈævənt ˈiːvən bˌɪn bˈɔːɹn.\n"
     ]
    }
   ],
   "source": [
    "speed = 1.0  # predicted durations are divided by speed, so > 1 speaks faster\n",
    "\n",
    "ps = phonemize(text, \"a\")\n",
    "token_ids = tokenize(ps)\n",
    "# Wrap the token sequence with token 0 on both ends before feeding the model.\n",
    "tokens = torch.LongTensor([[0, *token_ids, 0]]).to(device)\n",
    "# tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))  # optional fixed-length padding\n",
    "n_tokens = tokens.shape[-1]\n",
    "input_lengths = torch.LongTensor([n_tokens]).to(device)\n",
    "\n",
    "# Pure inference: skip autograd bookkeeping (no .detach() needed at the end).\n",
    "with torch.no_grad():\n",
    "    text_mask = length_to_mask(input_lengths).to(device)\n",
    "    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
    "    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
    "\n",
    "    # kokoro.generate picks the style vector with voicepack[len(tokens)] on the token list\n",
    "    # *before* the two 0 pads are added, so index by len(token_ids), not the padded length.\n",
    "    ref_s = voicepack[len(token_ids)]\n",
    "    s = ref_s[:, 128:]  # predictor half of the style; ref_s[:, :128] goes to the decoder\n",
    "\n",
    "    d = model.predictor.text_encoder.inference(d_en, s)\n",
    "    x, _ = model.predictor.lstm(d)\n",
    "    duration = model.predictor.duration_proj(x)\n",
    "    duration = torch.sigmoid(duration).sum(axis=-1) / speed\n",
    "    pred_dur = torch.round(duration).clamp(min=1).long()  # frames per token, at least 1\n",
    "\n",
    "    # Hard alignment: row i is 1 on frames [c_start[i], c_end[i]) covered by token i.\n",
    "    # Built on the CPU (as before) so it can be plotted directly.\n",
    "    max_mels = pred_dur.sum().item()\n",
    "    pred_aln_trg = torch.zeros(n_tokens, max_mels)\n",
    "    c_start = F.pad(pred_dur, (1, 0), \"constant\").cumsum(dim=1)[0, 0:-1]\n",
    "    c_end = c_start + pred_dur[0, :]\n",
    "    for row, cs, ce in zip(pred_aln_trg, c_start, c_end):\n",
    "        row[cs:ce] = 1\n",
    "    aln = pred_aln_trg.unsqueeze(0).to(device)\n",
    "\n",
    "    en = d.transpose(-1, -2) @ aln\n",
    "    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
    "    t_en = model.text_encoder.inference(tokens)\n",
    "    asr = t_en @ aln\n",
    "    output = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()\n",
    "\n",
    "display(Audio(data=output, rate=SAMPLE_RATE, autoplay=True))\n",
    "print(ps)  # phonemes of *this* cell's input, not out_ps from the generate() cell"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TorchScript attempt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recorded failure: torch.jit.script routes a dict to create_script_dict, and a dict of\n",
    "# nn.Modules has no TorchScript type (\"Unable to infer type of dictionary: Cannot infer\n",
    "# concrete type of torch.nn.Module\"). build_model's result is dict-like, so this dead end\n",
    "# is kept for reference; only that specific error is caught so Run All can continue.\n",
    "try:\n",
    "    scrpt = torch.jit.script(model)\n",
    "except torch.jit.Error as err:\n",
    "    print(f'torch.jit.script(model) failed: {err}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect intermediate tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 143, 640])\n",
      "torch.Size([1, 640, 143])\n",
      "torch.Size([1, 512, 143])\n",
      "torch.Size([1, 143, 348])\n",
      "en.shape=torch.Size([1, 640, 348])\n",
      "s.shape=torch.Size([1, 128])\n",
      "en.dtype=torch.float32\n",
      "s.dtype=torch.float32\n",
      "torch.Size([1, 512, 143])\n",
      "torch.Size([1, 512, 348])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type":
"execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAD8CAYAAABdPV+VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAyOUlEQVR4nO3deXwTdf4/8NckadODNqVnGmih3CA3QikqgnS5XOQSAUEBERYEFFBEXLn87S4sHqso6qoruF9BRZdjRWHlagFbChQKchWKhbbQtFBo0oOmOT6/P6rB0HInnSR9PR+PPB6Zz3xm8s7HVF/OfGZGEkIIEBEREbkRhdwFEBEREV2PAYWIiIjcDgMKERERuR0GFCIiInI7DChERETkdhhQiIiIyO0woBAREZHbYUAhIiIit8OAQkRERG6HAYWIiIjcjqwBZcWKFWjcuDH8/PwQHx+Pffv2yVkOERERuQnZAsrXX3+N2bNnY+HChTh48CA6dOiAfv36obCwUK6SiIiIyE1Icj0sMD4+Hl27dsX7778PALDZbIiJicGMGTPwyiuv3HRbm82GCxcuICgoCJIk1Ua5REREdI+EECgpKYFOp4NCcfNjJKpaqslBZWUl0tPTMW/ePHubQqFAYmIiUlNTq/U3mUwwmUz25fPnz6NNmza1UisRERE5V25uLho2bHjTPrIElEuXLsFqtSIqKsqhPSoqCidPnqzWf8mSJVi8eHG19iYzFkCp9nNZnTdj9ROwxFag+bISWLOyZamBiIjIk1hgxh78gKCgoFv2lSWg3Kl58+Zh9uzZ9mWj0YiYmBjEvncYKslHlppsD3bEiE+SsTL5MYQKFSy/nJWlDiIiIo/x66SS25meIUtACQ8Ph1KpREFBgUN7QUEBtFpttf5qtRpqtbq2yrstij0Z+E+bKPTP2IVvznRCg2FyV0REROQ9ZLmKx9fXF126dMH27dvtbTabDdu3b0dCQoIcJd0dIZD8YgKshzVQJ2uhjIiQuyIiIiKvINtlxrNnz8Ynn3yCzz//HCdOnMDUqVNRVlaGCRMmyFXSXfHZlo6wY1aYLCoUPtYMivat5C6JiIjI48k2B2XkyJG4ePEiFixYAL1ej44dO2LLli3VJs56gsBv06D4IQDvH92Ip9ZOR5OjyqoVNqu8hREREXko2e6Dci+MRiM0Gg16YbBsk2SrkSQoWzfHyan18be+awEA780fiXpr98pcGBERkXuwCDOSsBEGgwHBwcE37esRV/F4BCFgPX4KkWndMS9gOABA1UVCRUgCwj+ufm8XIiIiujEGFCfTfLEXmi+q3us3tEZcj4so3xgJ68UinvIhIiK6TXyasQvpnjwH/TtN8fH+dUC3++Quh4iIyGMwoLiQrbwcmoN6DHj/ZZz6kw/0L/SQuyQiIiKPwFM8LmbJPgfdsnOo3NoI56RwaLu3BwCozhXCkq+XuToiIiL3xCMotcT3D+cQeEyN/637N/637t84M6WJ3CURERG5LR5BqUWx355Hz9OTAQCmvhbo9lY9LOnsolbw3bJfztKIiIjcCgNKLbJkn4N/9jkAQL1mPXChpQYAUNjZB/UiEiAJgdBNJ2AtNshZJhERkewYUGSiezMFeLPqfcPtwH9brYdVCAw9NRHSwRMQNsHLkomIqM5iQHEDqmckDK03FkKSEPuvX9C//inkmkPx44NxsF65Ind5REREtY4BxQ1YzuXa3+/aloBtYVX3TPGZo4QkAN8rUtURFyIiojqCAcXNNP5z1W3xlfXrY/K+/Wjicwn/MXTBgS8bATabvZ+l4CJPARERkddiQHFT1itX8M/Onaret2uCz1JXwFeSAACVQmDSwGdhO3pSzhKJiIhchgHFjdlKSgAAqlN5GPjOy4B0bV3Z3KvwUbdz6F+pD0Dz59Nqs0QiIiKXYEDxANZLRYh++3dzUCQJyh3R6B2R6dBvd1RzVD
zQ0WmfqyyrhC3juNP2R0REdLsYUDyRELD2voBtCHJs7x6JH9etctrHzC3oiIxOTtsdERHRbWNA8SKKY9noNWmS0/ZnjFGh+/5D9uVSiy+KhvrDoi9w2mcQERHVhAHFi9hKSqD+3nm3zI/q2AbHh2nty2abAleHRkFZce05QhG79bBmZTvtM4mIiAAGFLoJW8ZxqPteW/b388Nfj21Agp/J3vbAoucRfrbqPi7CYqntEomIyEsxoNBts1VUYFm/wYBSaW8zLTVgxtyjAIDlT44A9v0sV3lERORFGFDojlx/Oke1vQem5z9d9f5xJaRhCZDMQJOlR2ArK5OjRCIi8gIMKHRPIt9PQeSv73V7gzA9ajsuWoPw7n+GQ3Em134vFyIiojuhkLsA8h76R6x4rW1vvPdgL7z0n6+RPaet3CUREZGHYkAhp7GVl8NWVgbrpSK8vGwyKjUCpz69H5CkW29MRET0Owwo5HTCYkH4P1OhvqxAfOtfYH24E5ThYXKXRUREHsTpAWXJkiXo2rUrgoKCEBkZiSFDhiAz0/GW7L169YIkSQ6vKVOmOLsUklns4hQYJ4Xj+9X/RNHAFnKXQ0REHsTpASU5ORnTpk3D3r17sXXrVpjNZvTt2xdl113RMWnSJOTn59tfy5Ytc3Yp5AbELznoP/E5FD9ahtbpKvsr+28JcpdGRERuzOlX8WzZssVhedWqVYiMjER6ejp69uxpbw8ICIBWq71+c/IytooK+G7ZD/FwAtKCG9nbzfWtKJpUPaREbT0Py9mc2iyRiIjckMsvMzYYDACA0NBQh/bVq1fjiy++gFarxaBBgzB//nwEBATUuA+TyQST6drdS41Go+sKJpeIm5fqsFz6chSSFr5drV9i+UyE5Fc960f87p85ERHVLZIQQrhq5zabDY899hiKi4uxZ88ee/vHH3+MRo0aQafT4ciRI5g7dy66deuGdevW1bifRYsWYfHixdXae2EwVJKPq8onF1IGBwNR4dXaDe8BC5t9BxsUeHfIUNiOnpShOiIicgWLMCMJG2EwGBAcHHzTvi4NKFOnTsXmzZuxZ88eNGzY8Ib9duzYgT59+iArKwtNmzattr6mIygxMTEMKF5I/0IPlDS1AgAUlRIgJCgrgSb/7xBsFRUyV0dERPfiTgKKy07xTJ8+HZs2bcKuXbtuGk4AID4+HgBuGFDUajXUarVL6iT3on03BVoAkCS0PqDEqPppOGsOx+df9oOq5NpEa6u+kKeAiIi8mNMDihACM2bMwPr165GUlIS4uLhbbpORkQEAiI6OdnY55KmEwMkeSiyWHoAiPAxLdn+OOB+bffWg6S/Af8M+GQskIiJXcnpAmTZtGtasWYONGzciKCgIer0eAKDRaODv748zZ85gzZo1GDhwIMLCwnDkyBHMmjULPXv2RPv27Z1dDnkwYTJBABAXL2Hi0pmwqa7dkdYwxATNhFawCgkxM0pgyc2Tr1AiInI6p89BkW5wW/OVK1di/PjxyM3NxdixY3H06FGUlZUhJiYGQ4cOxWuvvXbL81G/MRqN0Gg0nINSh51Z0xGj2qTDLJTY/ffu8C8029f5ZebDcv6CjNUREVFN3GaSrKswoJCdQomJJ7PwRD2Dvan9W88h+q0UGYsiIqKauMUkWaJaYbPik2eG4iM/pb2pZIwJAw5VvT/4fCcodh+SqTgiIrpbDCjk8RR7Mhye2RB4fw/siKh69k/5Q/4I13SD3yZOqCUi8iQMKOR1Giy9dnqndaoev/QMh7RVzcuSiYg8iNMfFkjkToqGBaDsv1rMPJYBVZPGcpdDRES3iUdQyKtZ8vWI2h+Bad9NAGYAgBaqcglxi9IhzJVyl0dERDfAgEJeT+z/Gc0PqRCffhW9653AzxUx2Lw6HlJFJVBphiXvvNwlEhHRdRhQqE4QFgvSutVDGrpCEaPDiu2fIUKpwsKCHjjaRe7qiIjoegwoVGf8NklWXCjAiP83B0IJmEIkBHx/EQBQsj
8CsYt4/xQiInfAgEJ1jq28HGGfpgIApPvbot3ILADA+ub1YE7sAt/dR3nFDxGRzHgVD9Vp4sBRHOkscKSzgGaXH9aveh9KnVbusoiI6jweQSH6lXbjLxiaPQO61VloGmjGzwYdyvsYeLUPEZEMGFCIfmXRF8D3UhFSxrbHkeBomEw+8HuuOSQb4H/JhuAv98pdIhFRncGAQvQ7wmJB86cPAgBUcY3wftIHCFEo8Gf9I8jeGADb1auA5z1fk4jI4zCgEN2A5Vwepj84ClBIKO6mw1tHP8CCsc9ASjksd2lERF6PAYXoRmxWWHLzAAAaHxWe+nY6rCMFMDoekkVCy4XHYTUaZS6SiMg7MaAQ3QZrVjaavJwN/+QoPNdgB4ptAfh07RCoLpVAslhhyT4nd4lERF6FAYXoDlT0KcLbUico6gXitfTPcb9vJTaVR+CzNs0hLBa5yyMi8hoMKER34LcQYjVYMWfhVFjVgCVAgv93BVBIAvmnItD8+TSZqyQi8nwMKER3w2ZFyP9V3Y1W1SgGujGXoVZYUBqrhmlgVwCA//lS2A6fkLNKIiKPxYBCdI8s53KRE1/1PuSPQUj69GMAQMvdT6PxSBkLIyLyYAwoRE4UuCcTf3hiPABA8UgAHvm5zL5ud1EzmHvreR8VIqLbwIBC5ETWYgMUezIAABFh3fDvDt3s60wmHwTOjAN+yycCaLA6E9ZLRbVfKBGRm2NAIXIR/4370HDjtWXlfS3xr83vw0eSAAA2AGMPTIPyasU9f5awWPgEZiLyKgwoRLXEeiILzyY8ca1BktDpuwyMDrn3q37GHR4P7RBOyCUi7+H0gLJo0SIsXrzYoa1ly5Y4efIkAKCiogIvvvgivvrqK5hMJvTr1w8ffPABoqKinF0KkXuxWWE5f8Gh6Ycve2B9aMK971oJnPqg6nSSdreCDzYkIo/nkiMo9913H7Zt23btQ1TXPmbWrFn4/vvv8c0330Cj0WD69OkYNmwYfvrpJ1eUQuTWdMtSnLKf0hHxmPu3LwAAL2As6mc0r1qRXwhrscEpn0FEVJtcElBUKhW0Wm21doPBgH/9619Ys2YNHnnkEQDAypUr0bp1a+zduxfdu3d3RTlEXq/et/vwwfo2AACf15VYv20NACD+7y8garlzQhARUW1SuGKnp0+fhk6nQ5MmTTBmzBjk5OQAANLT02E2m5GYmGjv26pVK8TGxiI1NfWG+zOZTDAajQ4vIvodIaomylosaLrWgIf+/Dwe+vPzMLQ149S/7pe7OiKiO+b0gBIfH49Vq1Zhy5Yt+PDDD5GdnY2HHnoIJSUl0Ov18PX1RUhIiMM2UVFR0Ov1N9znkiVLoNFo7K+YmBhnl03kNWwZx1F/VSrqr0qF6ooKEVoDKv7YDcqwULlLIyK6bU4PKAMGDMCIESPQvn179OvXDz/88AOKi4uxdu3au97nvHnzYDAY7K/c3FwnVkzkvZrMTYVmaT0kf/wxyrs1lbscIqLb5pJTPL8XEhKCFi1aICsrC1qtFpWVlSguLnboU1BQUOOcld+o1WoEBwc7vIjo9qgOnkLfx8fBNvMSTn3K0z1E5BlcHlBKS0tx5swZREdHo0uXLvDx8cH27dvt6zMzM5GTk4OEhHu/1JKIqrOVlUFKOYz8Q1rAokD+7B5QMuQTkZtzekB56aWXkJycjLNnzyIlJQVDhw6FUqnE6NGjodFoMHHiRMyePRs7d+5Eeno6JkyYgISEBF7BQ+RiTV5JRXiaEt++8AZEYx2UwcFQBAXJXRYRUY2cfplxXl4eRo8ejaKiIkRERODBBx/E3r17ERERAQD4xz/+AYVCgeHDhzvcqI2IXC9szUHM2jEKgzfvRo+AM9h7tQnWd4qFreLeb7dPRORMkhCe92hVo9EIjUaDXhgMleQjdzlEnkWhxPmX42EOFIACqAy3ANK11a3fvAzrqTPy1UdEXssizEjCRhgMhlvOJ+WzeIjqGpsVDZZW3bxN1bABhm1NR4
iy3L76re1PIsTnun81mC0MLURUqxhQiOowS955rL2vgUPbQ+l78ZfIdIe2H68G4r3W7SHMlbVZHhHVYQwoRHWdzeqwuHd+NzxY3/GqOnMgUH/zeagkG06d1qHFlH21WSER1UEMKETkwG/TPvhd16Zq2ACBYyvgpzTjQlQwyofG29dJNoGAHzJ4dIWInIoBhYhuyZJ3HoYHAQOAsIGBSPr0n/Z1hdYyTOg+ApbzF+QrkIi8DgMKEd0R/90n0X/wU/Zlq58KnTcdQpz6InZdaYGih4zVThsREd0pBhQiuiO2khJg/8/2ZaVajbUHukIZaIatUomAl9UAgMALAiH/vvFTyomIboYBhYjuiTCZ0GLSfgCAsmUzvPPj5/CTBGafG4Ly7+rDauARFSK6cy5/Fg8R1R3WU2cw8/4hmNJlKC6saIa3Dn0PqXNrucsiIg/EgEJEziMErBcvwnrxIkLSCzFs1UvIfDYAlybzYaBEdGcYUIjIJaynf0HsohR0aH0OV3pUQtG2FaBQyl0WEXkIBhQicqmrvQoRcFKND77/FMqwULnLISIPwUmyRORaQiB2XQFGnZ+Del+eR5BPFDILIxH7xFHA855VSkS1hAGFiFzOeuoMQs/rUfF0OPxUZoTUK0fp490gXZdP6mWXQqQfk6dIInIrDChEVCtsZWXw/UMZrgAI6RCFDT+8DxUc56Q03zoJzSdIjhvyKAtRncSAQkS1Tpz8BYMfG1+t3XewL544nm9f3ljYEaaH9bVYGRG5CwYUIqp1wmQCajiVExnTDX9tPPBav0oF/P/cpHq/DDPU3+93aY1EJC8GFCJyG/4b9qH5hmvLUqf7sGLDO9UuN3xkyyy03lt1RZCttKwq8BCRV2FAISK3JQ6fxIwug6u1+85U4b2D/wUADH/zZUS9l1LbpRGRizGgEJH7sllhvVRUrTl2awUGXX0ZAFDRwgrjvzsDAFq+WQ7bkZO1WiIRuQYDChF5HEXyIcQkV70/vSIeC7ptAgB80Hk4Ikobw/LLWfmKIyKnYEAhIo/WfFoavpQaAAA6pWXgxGgt/PvJXBQR3TPe6p6IPJ8QgBA4/WobFCVHI2h3OIJ2hyP/xR5yV0ZEd4kBhYi8hs+2dEQcMsNiU8JiU6I01oaSUd1RMqo7VE0ay10eEd0BBhQi8irqzftx9eECXH24AKqrEna/9QF2v/UBzj+qAySp6kVEbs/pAaVx48aQJKnaa9q0aQCAXr16VVs3ZcoUZ5dBRIRm/ziDPw4cgz8OHIOyhHI8fTIHT5/Mge3BjnKXRkS34PRJsvv374fVarUvHz16FH/4wx8wYsQIe9ukSZPw+uuv25cDAgKcXQYREawFhUBBIQAgMLUHXjMNAQD4DPSF8pEekGxA4+VHYTUaZaySiGri9IASERHhsLx06VI0bdoUDz/8sL0tICAAWq3W2R9NRHRDUe+lIOrX92E/1ceihptQbPPFgu/HQ5lXCAhbjfdcISJ5uPQy48rKSnzxxReYPXs2pN+d9129ejW++OILaLVaDBo0CPPnz7/pURSTyQTT725lbeT/7RDRPbjS34KZqj9CUvtifPIP6OGfi6TyxljToSlvm0/kJlwaUDZs2IDi4mKMHz/e3vbkk0+iUaNG0Ol0OHLkCObOnYvMzEysW7fuhvtZsmQJFi9e7MpSiagOsZWUVL1RKLHkg9GwBABCBZj+dRWSBCjO+SHu1VR5iySq4yQhhHDVzvv16wdfX1989913N+yzY8cO9OnTB1lZWWjatGmNfWo6ghITE4NeGAyV5OP0uomo7lHFNESvzScQpKjANxe6wGee5tpKIYCMkxAWi3wFEnkBizAjCRthMBgQHBx8074uO4Jy7tw5bNu27aZHRgAgPj4eAG4aUNRqNdRqtdNrJCL6jSU3D9vaBgEIgq1/FLZs/MS+7pK1DOMSnoAl77x8BRLVMS4LKCtXrkRkZCQeffTRm/bLyMgAAERHR7uqFCKiO+J/4Bc8NO1P9m
WbCgj6dx6i/K/9H9+5v7aC36Z9cpRHVCe4JKDYbDasXLkS48aNg0p17SPOnDmDNWvWYODAgQgLC8ORI0cwa9Ys9OzZE+3bt3dFKUREd8x6qQgB669d0SOp1Sh8qjEU0rUz4hc7qhCs6W5fDj5bAemnjNosk8iruSSgbNu2DTk5OXjmmWcc2n19fbFt2za88847KCsrQ0xMDIYPH47XXnvNFWUQETmFMJkQ8VgmrL9rC93ih51Tv7Ev37d7AuJSlb9uYKuat0JEd82lk2RdxWg0QqPRcJIsEclG1SgGtqBA+/K5waFYNG41AOD/fTwGujdT5CqNyG25xSRZIiJvZjmX67Cs1d2PuY2eAAAoo224MKcHdG+m8kgK0V3iwwKJiJzA58cDaDF5P1pM3g+br8DYp7dCpYuGxCsQie4KAwoRkZO1nHMYOyZ2x/spa2Ec2knucog8Ek/xEBE5ma2iAsrTeXjsvZdR2tuMoqHtIQTQfJ4BluxzcpdH5BF4BIWIyAWsV65A92YKFIEWTG27C8+124XirtEQCR2Abu0AhVLuEoncGo+gEBG5ULOxh7AZIYAk4alj32FKyHlkm0sxrctgPj2Z6CYYUIiIaoMQ+GZGf6ypp4TVV0L0+jPQ+NTDwYKGiBySyat9iK7DgEJEVEtU29OhAqAIDET20w0Q7GeCxabAlXFVd6QN1Fvgu2W/vEUSuQkGFCKiWmYrK0P4oFMAgIYd2+A/m96FUpIw/PRjsGxT8anJROAkWSIiWYljWXi8z5MYmvgk8tc0xpQTJ6Fs00LusohkxyMoREQyEuZKWDOzAAARAb6Ytf1JqJ5WQrIlQFkhodHf0yFMJpmrJKp9DChERG5CpB9Di4MSuhy04omQ/ThljsSqr/8AqbQcwmKBtaBQ7hKJag0DChGROxEChx6oh0PKRCjCQ/G3HauhU1nw4eWuSOngK3d1RLWGAYWIyM3Yysur3pjNGP/2LAgVYA4AxDclkCQBcSQYsa/zacnk3RhQiIjclK2iAlHvVQURRftWeGjkISglGz5HPGwPdoRy/wnOTyGvxat4iIg8gO3ISSS398eOdoEI+m8QNn39CRRNG8ldFpHLMKAQEXmY8B9/Qd8/TYPt/TLkvdpD7nKIXIIBhYjIw1j0BfD7IR1Z+ghc1VphGNudDx8kr8OAQkTkiWxWNH0yAz5GBd56/QMo62sASZK7KiKnYUAhIvJgTf+RicVjJmBMymEUPdtd7nKInIYBhYjIg1mLLkN1LBsLvh+B4lYCF+ZwTgp5BwYUIiIPZzUa0Wz2XtjCK9H98cNQxTSE5MObupFnY0AhIvISLZ49il9ebYUvUtbC9Eh7ucshuicMKEREXkKYK+F3LA+9l72Ec0/akLOAp3vIczGgEBF5EYu+AFHvpcC/ngnWNqWw9uoMa6/OUDVpLHdpRHfkjgPKrl27MGjQIOh0OkiShA0bNjisF0JgwYIFiI6Ohr+/PxITE3H69GmHPpcvX8aYMWMQHByMkJAQTJw4EaWlpff0RYiI6JqGw48heFsgtq35DNvWfIYTM6PkLonojtxxQCkrK0OHDh2wYsWKGtcvW7YMy5cvx0cffYS0tDQEBgaiX79+qKiosPcZM2YMjh07hq1bt2LTpk3YtWsXJk+efPffgoiIqon6Phu9JzyL3hOehcIswScpGpJaLXdZRLdFEkKIu95YkrB+/XoMGTIEQNXRE51OhxdffBEvvfQSAMBgMCAqKgqrVq3CqFGjcOLECbRp0wb79+/H/fffDwDYsmULBg4ciLy8POh0ult+rtFohEajQS8MhkryudvyiYjqjKKJCTD90QCfHzXQJl2C9cTpW29E5GQWYUYSNsJgMCA4OPimfZ06ByU7Oxt6vR6JiYn2No1Gg/j4eKSmpgIAUlNTERISYg8nAJCYmAiFQoG0tLQa92symWA0Gh1eRER0+8L+lYrYmSXYMO8N5PeJgKRWVx1N4d1nyU05NaDo9XoAQFSU47nOqK
go+zq9Xo/IyEiH9SqVCqGhofY+11uyZAk0Go39FRMT48yyiYjqBEvuBUx9eAxMvYx45fh+vHJ8P2wPdJC7LKIaqeQu4HbMmzcPs2fPti8bjUaGFCKiO2WzwpJ9DgH/S8DEvF/n/Q0DpKHdIVkkNF9yHNZig7w1Ev3KqQFFq9UCAAoKChAdHW1vLygoQMeOHe19CgsLHbazWCy4fPmyffvrqdVqqDmxi4jIKcI+SUXYr+/9k6MwO+Z/KLH5Yfn6J6AqMAAWKyy5ebLWSOTUUzxxcXHQarXYvn27vc1oNCItLQ0JCQkAgISEBBQXFyM9Pd3eZ8eOHbDZbIiPj3dmOUREdAsVfYuxpE13rIjvgRe/+BJrdn+FZ7bv4q3ySXZ3fASltLQUWVlZ9uXs7GxkZGQgNDQUsbGxmDlzJv7yl7+gefPmiIuLw/z586HT6exX+rRu3Rr9+/fHpEmT8NFHH8FsNmP69OkYNWrUbV3BQ0REziNMJggAqDTjz397FhY/CVY/QPntZUiSgOFMfTSbtVfuMqkOuuOAcuDAAfTu3du+/NvckHHjxmHVqlV4+eWXUVZWhsmTJ6O4uBgPPvggtmzZAj8/P/s2q1evxvTp09GnTx8oFAoMHz4cy5cvd8LXISKiu2KzIvSzqqstVTEN0XyMHmqFBTsULWBO7GLvJtkAn12HISwWuSqlOuKe7oMiF94HhYiodpgGdkXSp5/YlwutZZjQfQQs5y/IWBV5qju5D4pHXMVDRETyCEg5hT5jJ9qXrWoF4v5zEg38rPa23fMT4PfdPjnKIy/GgEJERDdkLTZAtePaRQ0+ajV+eroV6gVee3zJ1S4q+MU4PjlZYRGI+PwghMlUa7WSd2FAISKi2yZMJsSNPuzQVralCb5vu9qhLc8CvLx5BGwXLzn1822VZsBmvXVH8ngMKEREdE+Cx5XiSf+RDm0i0B+jtm5HU9/CG2x1d6a/OR2RH6Q4dZ/knhhQiIjonlgLqocQSa3GX/8zAlYn305FNLfBsPzaPbM0J5UMLF6KAYWIiJxOmExo/OdUp+/31Cdd8ekjn9mXX2o0Asof4+zLksUKy9kcp38u1T4GFCIi8hgtpx7CG8r77cvKsfWxNukf9uU0UyDebBcPW3m5HOWREzGgEBGRxxAWC/C7m8RFJuvRa/Es+7LFT4LP2ksI+GcIL332cAwoRETksaxZ2QjLyrYvK6Mi0WBsGX5qH46GRR0gpRy+ydbkzpz6sEAiIiI5WQsKkRNfBlXXK0j48IDc5dA9YEAhIiKv0/BVK354pyd6HK6ESOggdzl0F3iKh4iIvI71+ClEoAW+/aUjrH0C4du1ByQboF15GLayMrnLo9vAgEJERF7JevwUdEOBxvv88dfobSixCUzbPgFS1jkIc6Xc5dEtMKAQEZFXyx0UhHE+TwC+Pnh0014sX/9Hl9yjhZyLc1CIiMirWQsKYck7D8vZXLz/1SBYAgR+WZoAScX/R3dnDChERFQ32KyIfT0FPkYFxg3cCalVMyhbNoOqUYzclVENGB+JiKhOabQ4DSmr78Nn2z6FRuGLl/MfwumucldF12NAISKiusVmhTivx2ML5wASUBEuof7/8gEAF3+KRuzrfPigO2BAISKiOsdWVobQlb9OlO3eHk1HVAWULU1CYRrQFX7bDvNKH5lxDgoREdVte4/gTNcKnOlageCDanz78TtQRkXIXVWdx4BCRET0K93XWRj19AzErb+EC3N6yF1OncaAQkRE9CtrQSFUe47ih2NtcVVrw6XJCYAkyV1WncSAQkRE9DvCXInm49MhlMALs7+BMrQ+75kiAwYUIiKiGrSYfwxfTHgUC/ZvRfHI++Uup86544Cya9cuDBo0CDqdDpIkYcOGDfZ1ZrMZc+fORbt27RAYGAidToenn34aFy5ccNhH48aNIUmSw2vp0qX3/GWIiIicxVZSAtWpXIxbPR2F3QXy5nFOSm2644BSVlaGDh06YMWKFdXWlZeX4+
DBg5g/fz4OHjyIdevWITMzE4899li1vq+//jry8/PtrxkzZtzdNyAiInIRa9FlNJ6fClV4BXSJuVC2aQFJrZa7rDrhjk+qDRgwAAMGDKhxnUajwdatWx3a3n//fXTr1g05OTmIjY21twcFBUGr1d7pxxMREdW6uDFHcXVQF2zaugKDHn8WUuphuUvyei6fg2IwGCBJEkJCQhzaly5dirCwMHTq1AlvvPEGLBbLDfdhMplgNBodXkRERLXGZkXQvhz0nPc8zr9kQe5rPN3jai6dllxRUYG5c+di9OjRCA4Otrc///zz6Ny5M0JDQ5GSkoJ58+YhPz8fb7/9do37WbJkCRYvXuzKUomIiG7Kkq9HyL/1KB18H642rsTVwd0AAPWOX4L19C8yV+d9JCGEuOuNJQnr16/HkCFDqq0zm80YPnw48vLykJSU5BBQrvfZZ5/hT3/6E0pLS6Gu4dyeyWSCyWSyLxuNRsTExKAXBkMl+dxt+URERHflwks98PPsDwAArT55Do0W8vk9t8MizEjCRhgMhpvmAsBFR1DMZjOeeOIJnDt3Djt27LhlEfHx8bBYLDh79ixatmxZbb1ara4xuBAREckh5v+y0PencQAA0zgLAndFoKzXZcBmlbky7+H0gPJbODl9+jR27tyJsLCwW26TkZEBhUKByMhIZ5dDRETkdNaCQkgFhQCA0HYJyPCPQcCc5ojdUAhrZpbM1XmHOw4opaWlyMq6NvjZ2dnIyMhAaGgooqOj8fjjj+PgwYPYtGkTrFYr9Ho9ACA0NBS+vr5ITU1FWloaevfujaCgIKSmpmLWrFkYO3Ys6tev77xvRkREVAvCP05F1J4WWLF5OUYWzEFYXj5sZWVyl+Xx7ngOSlJSEnr37l2tfdy4cVi0aBHi4uJq3G7nzp3o1asXDh48iOeeew4nT56EyWRCXFwcnnrqKcyePfu2T+MYjUZoNBrOQSEiIvegUEIZEYbSfwfCR2mFKjFH7orc0p3MQbmnSbJyYUAhIiJ3dHlCAopbApYwC9r8RQ/LuVy5S3IrdxJQ+CweIiIiJwldmYoGyRb8pec6lHSKhqphA7lL8lgMKERERE6k3rwf/3dfHEb+bTNO/C1K7nI8FgMKERGRkwmLBWvnDoAqTw2fpGj4JEWj+KkEucvyKAwoRERELuD33T6EZALBPhUI9qnAldZA2fB4lA2Ph0rLIyu34tJb3RMREdVl9T9PRdHnVe8tn5uwZ/xnAIAHZk5BvbUFMlbm/hhQiIiIakHrhZfQ/72xAIDKRZcxfHEhzEKJ74fG8+ZuNWBAISIiqgWWsznA2ar3pft6YJmhHwDAZ5Q/FJWRUJoA3XsHIMyV8hXpRhhQiIiIalns678+XFChxCOHjRitOYSTlfXx7roBEIYSez9bSQmExSJTlfJiQCEiIpKLzYrkXjFIVjQGQjV4efs66FTXAsrkqTOh/mG/fPXJiAGFiIhIRtaiywAARVkZpn06BeJ3/2W++pgF0hNdIGwS2ryaC4u+7kysZUAhIiJyA7bycjRckuLQdvG/LTGzxXaYhQr/13kQ/C+E2ddJuXp7uPFGDChERERuKmLwKaxGDCRfX8z4+WsMCSy2r+u8bDq076bceGMPx4BCRETkrn59nq8wmbD8xVF4K+Da/VWNf6hE010R9uUj+5qi6Ut7a71EV2FAISIi8gB+3+2D3++Wi9onQNFS2JdtkZUofaJ7rddV72wZsO9np++XAYWIiMgDNf5zKkp+txw6sTl2/WNFrdfRMvkZNB0j3bqjELfu8zsMKERERF4gct1JPHrwqVr/XJ9BARh94vwt+/1l0zA0enH3be+XAYWIiMgLWK9cAa5cqfXPjdJ1xetNBt26ow+QNzce+PvG29qvRwYU8ethIgvMwJ0dMSIiIiInUm5KQdymW/fLmxuP7SM/RqO/X/vv+M1I4nZ6uZm8vDzExMTIXQYRERHdhdzcXDRs2PCmfTwyoNhsNmRmZqJNmzbIzc
1FcHCw3CW5FaPRiJiYGI5NDTg2N8fxuTGOzY1xbG6MY+NICIGSkhLodDooFIqb9vXIUzwKhQINGjQAAAQHB/Mf+g1wbG6MY3NzHJ8b49jcGMfmxjg212g0mtvqd/P4QkRERCQDBhQiIiJyOx4bUNRqNRYuXAi1Wi13KW6HY3NjHJub4/jcGMfmxjg2N8axuXseOUmWiIiIvJvHHkEhIiIi78WAQkRERG6HAYWIiIjcDgMKERERuR2PDCgrVqxA48aN4efnh/j4eOzbt0/ukmrdokWLIEmSw6tVq1b29RUVFZg2bRrCwsJQr149DB8+HAUFBTJW7Fq7du3CoEGDoNPpIEkSNmzY4LBeCIEFCxYgOjoa/v7+SExMxOnTpx36XL58GWPGjEFwcDBCQkIwceJElJaW1uK3cI1bjc348eOr/Zb69+/v0Mdbx2bJkiXo2rUrgoKCEBkZiSFDhiAzM9Ohz+38LeXk5ODRRx9FQEAAIiMjMWfOHFgsltr8Kk53O2PTq1evar+dKVOmOPTxxrH58MMP0b59e/vN1xISErB582b7+rr6m3E2jwsoX3/9NWbPno2FCxfi4MGD6NChA/r164fCwkK5S6t19913H/Lz8+2vPXv22NfNmjUL3333Hb755hskJyfjwoULGDZsmIzVulZZWRk6dOiAFStW1Lh+2bJlWL58OT766COkpaUhMDAQ/fr1Q0VFhb3PmDFjcOzYMWzduhWbNm3Crl27MHny5Nr6Ci5zq7EBgP79+zv8lr788kuH9d46NsnJyZg2bRr27t2LrVu3wmw2o2/fvigrK7P3udXfktVqxaOPPorKykqkpKTg888/x6pVq7BgwQI5vpLT3M7YAMCkSZMcfjvLli2zr/PWsWnYsCGWLl2K9PR0HDhwAI888ggGDx6MY8eOAai7vxmnEx6mW7duYtq0afZlq9UqdDqdWLJkiYxV1b6FCxeKDh061LiuuLhY+Pj4iG+++cbeduLECQFApKam1lKF8gEg1q9fb1+22WxCq9WKN954w95WXFws1Gq1+PLLL4UQQhw/flwAEPv377f32bx5s5AkSZw/f77Wane168dGCCHGjRsnBg8efMNt6srYCCFEYWGhACCSk5OFELf3t/TDDz8IhUIh9Hq9vc+HH34ogoODhclkqt0v4ELXj40QQjz88MPihRdeuOE2dWVshBCifv364tNPP+Vvxok86ghKZWUl0tPTkZiYaG9TKBRITExEamqqjJXJ4/Tp09DpdGjSpAnGjBmDnJwcAEB6ejrMZrPDOLVq1QqxsbF1cpyys7Oh1+sdxkOj0SA+Pt4+HqmpqQgJCcH9999v75OYmAiFQoG0tLRar7m2JSUlITIyEi1btsTUqVNRVFRkX1eXxsZgMAAAQkNDAdze31JqairatWuHqKgoe59+/frBaDTa/4/aG1w/Nr9ZvXo1wsPD0bZtW8ybNw/l5eX2dXVhbKxWK7766iuUlZUhISGBvxkn8qiHBV66dAlWq9XhHyoAREVF4eTJkzJVJY/4+HisWrUKLVu2RH5+PhYvXoyHHnoIR48ehV6vh6+vL0JCQhy2iYqKgl6vl6dgGf32nWv63fy2Tq/XIzIy0mG9SqVCaGio149Z//79MWzYMMTFxeHMmTN49dVXMWDAAKSmpkKpVNaZsbHZbJg5cyYeeOABtG3bFgBu629Jr9fX+Nv6bZ03qGlsAODJJ59Eo0aNoNPpcOTIEcydOxeZmZlYt24dAO8em59//hkJCQmoqKhAvXr1sH79erRp0wYZGRn8zTiJRwUUumbAgAH29+3bt0d8fDwaNWqEtWvXwt/fX8bKyNOMGjXK/r5du3Zo3749mjZtiqSkJPTp00fGymrXtGnTcPToUYe5XFTlRmPz+3lI7dq1Q3R0NPr06YMzZ86gadOmtV1mrWrZsiUyMjJgMBjw7bffYty4cUhOTpa7LK/iUad4wsPDoVQqq82GLigogFarlakq9xASEoIWLVogKy
sLWq0WlZWVKC4uduhTV8fpt+98s9+NVqutNtHaYrHg8uXLdW7MmjRpgvDwcGRlZQGoG2Mzffp0bNq0CTt37kTDhg3t7bfzt6TVamv8bf22ztPdaGxqEh8fDwAOvx1vHRtfX180a9YMXbp0wZIlS9ChQwe8++67/M04kUcFFF9fX3Tp0gXbt2+3t9lsNmzfvh0JCQkyVia/0tJSnDlzBtHR0ejSpQt8fHwcxikzMxM5OTl1cpzi4uKg1WodxsNoNCItLc0+HgkJCSguLkZ6erq9z44dO2Cz2ez/0q0r8vLyUFRUhOjoaADePTZCCEyfPh3r16/Hjh07EBcX57D+dv6WEhIS8PPPPzuEuK1btyI4OBht2rSpnS/iArcam5pkZGQAgMNvxxvHpiY2mw0mk6lO/2acTu5Zunfqq6++Emq1WqxatUocP35cTJ48WYSEhDjMhq4LXnzxRZGUlCSys7PFTz/9JBITE0V4eLgoLCwUQggxZcoUERsbK3bs2CEOHDggEhISREJCgsxVu05JSYk4dOiQOHTokAAg3n77bXHo0CFx7tw5IYQQS5cuFSEhIWLjxo3iyJEjYvDgwSIuLk5cvXrVvo/+/fuLTp06ibS0NLFnzx7RvHlzMXr0aLm+ktPcbGxKSkrESy+9JFJTU0V2drbYtm2b6Ny5s2jevLmoqKiw78Nbx2bq1KlCo9GIpKQkkZ+fb3+Vl5fb+9zqb8lisYi2bduKvn37ioyMDLFlyxYREREh5s2bJ8dXcppbjU1WVpZ4/fXXxYEDB0R2drbYuHGjaNKkiejZs6d9H946Nq+88opITk4W2dnZ4siRI+KVV14RkiSJH3/8UQhRd38zzuZxAUUIId577z0RGxsrfH19Rbdu3cTevXvlLqnWjRw5UkRHRwtfX1/RoEEDMXLkSJGVlWVff/XqVfHcc8+J+vXri4CAADF06FCRn58vY8WutXPnTgGg2mvcuHFCiKpLjefPny+ioqKEWq0Wffr0EZmZmQ77KCoqEqNHjxb16tUTwcHBYsKECaKkpESGb+NcNxub8vJy0bdvXxERESF8fHxEo0aNxKRJk6oFfm8dm5rGBYBYuXKlvc/t/C2dPXtWDBgwQPj7+4vw8HDx4osvCrPZXMvfxrluNTY5OTmiZ8+eIjQ0VKjVatGsWTMxZ84cYTAYHPbjjWPzzDPPiEaNGglfX18REREh+vTpYw8nQtTd34yzSUIIUXvHa4iIiIhuzaPmoBAREVHdwIBCREREbocBhYiIiNwOAwoRERG5HQYUIiIicjsMKEREROR2GFCIiIjI7TCgEBERkdthQCEiIiK3w4BCREREbocBhYiIiNwOAwoRERG5nf8PgBiqWsr6jrYAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(d.shape)\n", "print(d.transpose(-1,-2).shape)\n", "print(d_en.shape)\n", "print(pred_aln_trg.unsqueeze(0).shape)\n", "print(f\"{en.shape=}\")\n", "print(f\"{s.shape=}\")\n", "print(f\"{en.dtype=}\")\n", "print(f\"{s.dtype=}\")\n", "\n", "print(t_en.shape)\n", "print(asr.shape)\n", "pl.imshow(pred_aln_trg[:,:])\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Export StyleTTS2 model" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0101 18:56:11.998000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3557] create_symbol s0 = 143 for L['args'][0][0].size()[1] [2, 510] (_export/non_strict_utils.py:109 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"s0\"\n", "I0101 18:56:12.007000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:4857] set_replacement s0 = 143 (range_refined_to_singleton) VR[143, 143]\n", "I0101 18:56:12.008000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] eval Eq(s0, 143) [guard added] (mp/ipykernel_2488298/2554868606.py:17 in forward), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s0, 143)\"\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0101 18:56:27.383000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:3317] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:390 in local_scalar_dense)\n", "I0101 18:56:27.387000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5106] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4957 in arange), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"u0 >= 0\"\n", "W0101 18:56:28.575000 2488298 site-packages/torch/fx/experimental/symbolic_shapes.py:5124] failed during evaluate_expr(u0, hint=None, size_oblivious=False, forcing_spec=False\n", "E0101 
18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] failed while running evaluate_expr(*(u0, None), **{'fx_node': False})\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Traceback (most recent call last):\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py\", line 262, in wrapper\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return retlog(fn(*args, **kwargs))\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5122, in evaluate_expr\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return self._evaluate_expr(orig_expr, hint, fx_node, size_oblivious, forcing_spec=forcing_spec)\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py\", line 5238, in _evaluate_expr\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] raise self._make_data_dependent_error(\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). 
(Size-like symbols: u0)\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] Potential framework code culprit (scroll up for full backtrace):\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] return self._op_dk(dk, *args, **kwargs)\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more information, run with TORCH_LOGS=\"dynamic\"\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] \n", "E0101 18:56:28.576000 2488298 site-packages/torch/fx/experimental/recording.py:298] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" ] }, { "ename": "GuardOnDataDependentSymNode", "evalue": "Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). 
(Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[39], line 61\u001b[0m\n\u001b[1;32m 58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:batch, \u001b[38;5;241m1\u001b[39m:token_len}}\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstyle_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(style_model, args=( tokens, ), strict=False)\u001b[39;00m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in \u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 269\u001b[0m )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1011\u001b[0m log_export_usage(\n\u001b[1;32m 1012\u001b[0m 
event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m 1014\u001b[0m message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 1015\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 1016\u001b[0m )\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1019\u001b[0m _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 989\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 992\u001b[0m log_export_usage(\n\u001b[1;32m 993\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 994\u001b[0m metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m 995\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 996\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m 997\u001b[0m )\n", "File 
\u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/exported_program.py:114\u001b[0m, in \u001b[0;36m_disable_prexisiting_fake_mode..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m unset_fake_temporarily():\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1880\u001b[0m, in \u001b[0;36m_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 1877\u001b[0m \u001b[38;5;66;03m# Call the appropriate export function based on the strictness of tracing.\u001b[39;00m\n\u001b[1;32m 1878\u001b[0m export_func \u001b[38;5;241m=\u001b[39m _strict_export \u001b[38;5;28;01mif\u001b[39;00m strict \u001b[38;5;28;01melse\u001b[39;00m _non_strict_export\n\u001b[0;32m-> 1880\u001b[0m export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43mexport_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_in_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_complex_guards_as_runtime_asserts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1891\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1892\u001b[0m export_graph_signature: ExportGraphSignature \u001b[38;5;241m=\u001b[39m export_artifact\u001b[38;5;241m.\u001b[39maten\u001b[38;5;241m.\u001b[39msig\n\u001b[1;32m 1894\u001b[0m forward_arg_names \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1895\u001b[0m _get_forward_arg_names(mod, args, kwargs) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_torch_jit_trace \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1896\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1683\u001b[0m, in \u001b[0;36m_non_strict_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, dispatch_tracing_mode)\u001b[0m\n\u001b[1;32m 1667\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) \u001b[38;5;28;01mas\u001b[39;00m (\n\u001b[1;32m 1668\u001b[0m 
patched_mod,\n\u001b[1;32m 1669\u001b[0m new_fake_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1672\u001b[0m map_fake_to_real,\n\u001b[1;32m 1673\u001b[0m ):\n\u001b[1;32m 1674\u001b[0m _to_aten_func \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1675\u001b[0m _export_to_aten_ir_make_fx\n\u001b[1;32m 1676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dispatch_tracing_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmake_fx\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1681\u001b[0m )\n\u001b[1;32m 1682\u001b[0m )\n\u001b[0;32m-> 1683\u001b[0m aten_export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43m_to_aten_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1684\u001b[0m \u001b[43m \u001b[49m\u001b[43mpatched_mod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1685\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1686\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1687\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_params_buffers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1688\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_constant_attrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1689\u001b[0m \u001b[43m \u001b[49m\u001b[43mproduce_guards_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_produce_guards_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1690\u001b[0m \u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_tuplify_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1691\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;66;03m# aten_export_artifact.constants contains only fake script objects, we need to map them back\u001b[39;00m\n\u001b[1;32m 1693\u001b[0m 
aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 1694\u001b[0m fqn: map_fake_to_real[obj] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, FakeScriptObject) \u001b[38;5;28;01melse\u001b[39;00m obj\n\u001b[1;32m 1695\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fqn, obj \u001b[38;5;129;01min\u001b[39;00m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 1696\u001b[0m }\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:637\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;66;03m# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,\u001b[39;00m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;66;03m# otherwise aot_export_module will error out because it sees a mix of fake_modes.\u001b[39;00m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;66;03m# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.\u001b[39;00m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mstateless\u001b[38;5;241m.\u001b[39m_reparametrize_module(\n\u001b[1;32m 631\u001b[0m mod,\n\u001b[1;32m 632\u001b[0m fake_params_buffers,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 635\u001b[0m stack_weights\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 636\u001b[0m ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(): \u001b[38;5;66;03m# type: ignore[attr-defined]\u001b[39;00m\n\u001b[0;32m--> 637\u001b[0m gm, graph_signature \u001b[38;5;241m=\u001b[39m 
\u001b[43mtransform\u001b[49m\u001b[43m(\u001b[49m\u001b[43maot_export_module\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 638\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 639\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 640\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrace_joint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 641\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 642\u001b[0m \u001b[43m \u001b[49m\u001b[43mdecompositions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecomp_table\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 643\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 644\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 646\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_maybe_fixup_gm_and_output_node_meta\u001b[39m(old_gm, new_gm):\n\u001b[1;32m 647\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(old_gm, torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mGraphModule):\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1611\u001b[0m, in \u001b[0;36m_non_strict_export.._tuplify_outputs.._aot_export_non_strict\u001b[0;34m(mod, args, kwargs, **flags)\u001b[0m\n\u001b[1;32m 1605\u001b[0m new_preserved_call_signatures \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 1606\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_export_root.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m i \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m preserve_module_call_signature\n\u001b[1;32m 1607\u001b[0m 
]\n\u001b[1;32m 1608\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _wrap_submodules(\n\u001b[1;32m 1609\u001b[0m wrapped_mod, new_preserved_call_signatures, module_call_specs\n\u001b[1;32m 1610\u001b[0m ):\n\u001b[0;32m-> 1611\u001b[0m gm, sig \u001b[38;5;241m=\u001b[39m \u001b[43maot_export\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwrapped_mod\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflags\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1612\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExported program from AOTAutograd:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, gm)\n\u001b[1;32m 1614\u001b[0m sig\u001b[38;5;241m.\u001b[39mparameters \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_map(_strip_root, sig\u001b[38;5;241m.\u001b[39mparameters)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1246\u001b[0m, in \u001b[0;36maot_export_module\u001b[0;34m(mod, args, decompositions, trace_joint, output_loss_index, pre_dispatch, dynamic_shapes, kwargs)\u001b[0m\n\u001b[1;32m 1243\u001b[0m full_args\u001b[38;5;241m.\u001b[39mextend(args)\n\u001b[1;32m 1245\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx():\n\u001b[0;32m-> 1246\u001b[0m fx_g, metadata, in_spec, out_spec \u001b[38;5;241m=\u001b[39m \u001b[43m_aot_export_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1247\u001b[0m \u001b[43m \u001b[49m\u001b[43mfn_to_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1248\u001b[0m \u001b[43m \u001b[49m\u001b[43mfull_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1249\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdecompositions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdecompositions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1250\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_params_buffers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mparams_len\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1251\u001b[0m \u001b[43m \u001b[49m\u001b[43mno_tangents\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 1252\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1253\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1254\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1255\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1256\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trace_joint:\n\u001b[1;32m 1258\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mflattened_joint\u001b[39m(\u001b[38;5;241m*\u001b[39margs):\n\u001b[1;32m 1259\u001b[0m \u001b[38;5;66;03m# The idea here is that the joint graph that AOTAutograd creates has some strict properties:\u001b[39;00m\n\u001b[1;32m 1260\u001b[0m \u001b[38;5;66;03m# (1) It accepts two arguments (primals, tangents), and pytree_flattens them\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1273\u001b[0m \u001b[38;5;66;03m# This function \"fixes\" both of the above by removing any tangent inputs,\u001b[39;00m\n\u001b[1;32m 1274\u001b[0m \u001b[38;5;66;03m# and removing pytrees from the original FX graph.\u001b[39;00m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:1480\u001b[0m, in 
\u001b[0;36m_aot_export_function\u001b[0;34m(func, args, num_params_buffers, decompositions, no_tangents, pre_dispatch, dynamic_shapes, kwargs)\u001b[0m\n\u001b[1;32m 1477\u001b[0m fake_mode, shape_env \u001b[38;5;241m=\u001b[39m construct_fake_mode(flat_args, aot_config)\n\u001b[1;32m 1478\u001b[0m fake_flat_args \u001b[38;5;241m=\u001b[39m process_inputs(flat_args, aot_config, fake_mode, shape_env)\n\u001b[0;32m-> 1480\u001b[0m fx_g, meta \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_aot_dispatcher_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1481\u001b[0m \u001b[43m \u001b[49m\u001b[43mflat_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1482\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_flat_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1483\u001b[0m \u001b[43m \u001b[49m\u001b[43maot_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1484\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1485\u001b[0m \u001b[43m \u001b[49m\u001b[43mshape_env\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1486\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fx_g, meta, in_spec, out_spec\u001b[38;5;241m.\u001b[39mspec\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:522\u001b[0m, in \u001b[0;36mcreate_aot_dispatcher_function\u001b[0;34m(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_aot_dispatcher_function\u001b[39m(\n\u001b[1;32m 515\u001b[0m flat_fn,\n\u001b[1;32m 516\u001b[0m fake_flat_args: FakifiedFlatArgs,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 519\u001b[0m shape_env: Optional[ShapeEnv],\n\u001b[1;32m 520\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Callable, ViewAndMutationMeta]:\n\u001b[1;32m 521\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m 
dynamo_timed(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcreate_aot_dispatcher_function\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 522\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_create_aot_dispatcher_function\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 523\u001b[0m \u001b[43m \u001b[49m\u001b[43mflat_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfake_flat_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maot_config\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape_env\u001b[49m\n\u001b[1;32m 524\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py:623\u001b[0m, in \u001b[0;36m_create_aot_dispatcher_function\u001b[0;34m(flat_fn, fake_flat_args, aot_config, fake_mode, shape_env)\u001b[0m\n\u001b[1;32m 621\u001b[0m ctx \u001b[38;5;241m=\u001b[39m nullcontext()\n\u001b[1;32m 622\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx:\n\u001b[0;32m--> 623\u001b[0m fw_metadata \u001b[38;5;241m=\u001b[39m \u001b[43mrun_functionalized_fw_and_collect_metadata\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 624\u001b[0m \u001b[43m \u001b[49m\u001b[43mflat_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 625\u001b[0m \u001b[43m \u001b[49m\u001b[43mstatic_input_indices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maot_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstatic_input_indices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 626\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_input_mutations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maot_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkeep_inference_input_mutations\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 627\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mis_train\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mneeds_autograd\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 628\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maot_config\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m_dup_fake_script_obj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfake_flat_args\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 631\u001b[0m req_subclass_dispatch \u001b[38;5;241m=\u001b[39m requires_subclass_dispatch(\n\u001b[1;32m 632\u001b[0m fake_flat_args, fw_metadata\n\u001b[1;32m 633\u001b[0m )\n\u001b[1;32m 635\u001b[0m output_and_mutation_safe \u001b[38;5;241m=\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28many\u001b[39m(\n\u001b[1;32m 636\u001b[0m x\u001b[38;5;241m.\u001b[39mrequires_grad\n\u001b[1;32m 637\u001b[0m \u001b[38;5;66;03m# view-type operations preserve requires_grad even in no_grad.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m fw_metadata\u001b[38;5;241m.\u001b[39minput_info\n\u001b[1;32m 653\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py:173\u001b[0m, in \u001b[0;36mrun_functionalized_fw_and_collect_metadata..inner\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m disable_above, mode, suppress_pending:\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# precondition: The passed in function already handles unflattening inputs + flattening outputs\u001b[39;00m\n\u001b[1;32m 172\u001b[0m flat_f_args \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_map(_to_fun, 
flat_args)\n\u001b[0;32m--> 173\u001b[0m flat_f_outs \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mflat_f_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;66;03m# We didn't do any tracing, so we don't need to process the\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;66;03m# unbacked symbols, they will just disappear into the ether.\u001b[39;00m\n\u001b[1;32m 176\u001b[0m \u001b[38;5;66;03m# Also, prevent memoization from applying.\u001b[39;00m\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fake_mode:\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:182\u001b[0m, in \u001b[0;36mcreate_tree_flattened_fn..flat_fn\u001b[0;34m(*flat_args)\u001b[0m\n\u001b[1;32m 180\u001b[0m \u001b[38;5;28;01mnonlocal\u001b[39;00m out_spec\n\u001b[1;32m 181\u001b[0m args, kwargs \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_unflatten(flat_args, tensor_args_spec)\n\u001b[0;32m--> 182\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 183\u001b[0m flat_out, spec \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_flatten(tree_out)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m flat_out:\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py:863\u001b[0m, in \u001b[0;36mcreate_functional_call..functional_call\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 859\u001b[0m out \u001b[38;5;241m=\u001b[39m PropagateUnbackedSymInts(mod)\u001b[38;5;241m.\u001b[39mrun(\n\u001b[1;32m 
860\u001b[0m \u001b[38;5;241m*\u001b[39margs[params_len:], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 861\u001b[0m )\n\u001b[1;32m 862\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 863\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mmod\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[43mparams_len\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 865\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, (\u001b[38;5;28mtuple\u001b[39m, \u001b[38;5;28mlist\u001b[39m)):\n\u001b[1;32m 866\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 867\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGraph output must be a (). This is so that we can avoid \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 868\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpytree processing of the outputs. 
Please change the module to \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 869\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhave tuple outputs or use aot_module instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 870\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1598\u001b[0m, in \u001b[0;36m_non_strict_export.._tuplify_outputs.._aot_export_non_strict..Wrapper.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1594\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mfx\u001b[38;5;241m.\u001b[39mInterpreter(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_export_root)\u001b[38;5;241m.\u001b[39mrun(\n\u001b[1;32m 1595\u001b[0m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 1596\u001b[0m )\n\u001b[1;32m 1597\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1598\u001b[0m tree_out \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_export_root\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1599\u001b[0m flat_outs, out_spec \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_flatten(tree_out)\n\u001b[1;32m 1600\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(flat_outs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m 
\u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "Cell \u001b[0;32mIn[39], line 46\u001b[0m, in \u001b[0;36mStyleTTS2.forward\u001b[0;34m(self, tokens)\u001b[0m\n\u001b[1;32m 42\u001b[0m pred_aln_trg\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mvstack(pred_aln_trg_list)\n\u001b[1;32m 44\u001b[0m en \u001b[38;5;241m=\u001b[39m d\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m) \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[0;32m---> 46\u001b[0m F0_pred, N_pred \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpredictor\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mF0Ntrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43men\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 47\u001b[0m t_en \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext_encoder\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39minference(tokens)\n\u001b[1;32m 48\u001b[0m asr \u001b[38;5;241m=\u001b[39m t_en \u001b[38;5;241m@\u001b[39m pred_aln_trg\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n", "File \u001b[0;32m~/Projects/DeepLearning/TTS/Kokoro-82M/models.py:471\u001b[0m, in \u001b[0;36mProsodyPredictor.F0Ntrain\u001b[0;34m(self, x, s)\u001b[0m\n\u001b[1;32m 466\u001b[0m x1 \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 467\u001b[0m \u001b[38;5;66;03m# torch._check(x1.dim() == 3, lambda: print(f\"Expected 3D tensor, got {x1.dim()}D tensor\"))\u001b[39;00m\n\u001b[1;32m 468\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[1] > 0, lambda: print(f\"Shape 2, got {x1.shape[1]}\"))\u001b[39;00m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;66;03m# torch._check(x1.shape[2] > 0, lambda: print(f\"Shape 2, got {x1.shape[2]}\"))\u001b[39;00m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.shape[2]}\"))\u001b[39;00m\n\u001b[0;32m--> 471\u001b[0m x2, _temp \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshared\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx1\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 472\u001b[0m \u001b[38;5;66;03m# torch._check(x.shape[2] > 0, lambda: print(f\"Shape 2, got {x.size(2)}\"))\u001b[39;00m\n\u001b[1;32m 474\u001b[0m F0 \u001b[38;5;241m=\u001b[39m x2\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m2\u001b[39m)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1736\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1745\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1746\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1747\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123\u001b[0m, in \u001b[0;36mLSTM.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 1120\u001b[0m hx \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpermute_hidden(hx, sorted_indices)\n\u001b[1;32m 1122\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1123\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlstm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1124\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1125\u001b[0m \u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1126\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1127\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1128\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1129\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1130\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1132\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1133\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1134\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1135\u001b[0m result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mlstm(\n\u001b[1;32m 1136\u001b[0m \u001b[38;5;28minput\u001b[39m,\n\u001b[1;32m 1137\u001b[0m batch_sizes,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional,\n\u001b[1;32m 1145\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:520\u001b[0m, in 
\u001b[0;36m_NonStrictTorchFunctionHandler.__torch_function__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 512\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\n\u001b[1;32m 513\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m called at \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m:\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m in \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 514\u001b[0m func\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 517\u001b[0m frame\u001b[38;5;241m.\u001b[39mf_code\u001b[38;5;241m.\u001b[39mco_name,\n\u001b[1;32m 518\u001b[0m )\n\u001b[1;32m 519\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 520\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 521\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m GuardOnDataDependentSymNode \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 522\u001b[0m _suggest_fixes_for_data_dependent_error_non_strict(e)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_decomp/decompositions.py:3476\u001b[0m, in \u001b[0;36mlstm_impl\u001b[0;34m(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)\u001b[0m\n\u001b[1;32m 3474\u001b[0m hidden \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(hx[\u001b[38;5;241m0\u001b[39m], hx[\u001b[38;5;241m1\u001b[39m]))\n\u001b[1;32m 3475\u001b[0m layer_fn \u001b[38;5;241m=\u001b[39m select_one_layer_lstm_function(\u001b[38;5;28minput\u001b[39m, hx, params)\n\u001b[0;32m-> 3476\u001b[0m out, final_hiddens 
\u001b[38;5;241m=\u001b[39m \u001b[43m_rnn_helper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3477\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3478\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3479\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3480\u001b[0m \u001b[43m \u001b[49m\u001b[43mhas_biases\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3481\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3482\u001b[0m \u001b[43m \u001b[49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3483\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3484\u001b[0m \u001b[43m \u001b[49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3485\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3486\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer_fn\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3487\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3488\u001b[0m final_hiddens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39mfinal_hiddens))\n\u001b[1;32m 3489\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out, torch\u001b[38;5;241m.\u001b[39mstack(final_hiddens[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m0\u001b[39m), torch\u001b[38;5;241m.\u001b[39mstack(final_hiddens[\u001b[38;5;241m1\u001b[39m], \u001b[38;5;241m0\u001b[39m)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_decomp/decompositions.py:3151\u001b[0m, in \u001b[0;36m_rnn_helper\u001b[0;34m(input, hidden, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, layer_fn)\u001b[0m\n\u001b[1;32m 3147\u001b[0m cur_params, cur_hidden, bidir_params, bidir_hidden 
\u001b[38;5;241m=\u001b[39m params_hiddens(\n\u001b[1;32m 3148\u001b[0m params, hidden, i, bidirectional\n\u001b[1;32m 3149\u001b[0m )\n\u001b[1;32m 3150\u001b[0m dropout \u001b[38;5;241m=\u001b[39m dropout \u001b[38;5;28;01mif\u001b[39;00m (train \u001b[38;5;129;01mand\u001b[39;00m num_layers \u001b[38;5;241m<\u001b[39m i \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;241m0.0\u001b[39m\n\u001b[0;32m-> 3151\u001b[0m fwd_inp, fwd_hidden \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcur_hidden\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcur_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhas_biases\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3152\u001b[0m final_hiddens\u001b[38;5;241m.\u001b[39mappend(fwd_hidden)\n\u001b[1;32m 3154\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bidirectional:\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_decomp/decompositions.py:3333\u001b[0m, in \u001b[0;36mone_layer_lstm\u001b[0;34m(inp, hidden, params, has_biases, reverse)\u001b[0m\n\u001b[1;32m 3331\u001b[0m precomputed_input \u001b[38;5;241m=\u001b[39m precomputed_input\u001b[38;5;241m.\u001b[39mflip(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m reverse \u001b[38;5;28;01melse\u001b[39;00m precomputed_input\n\u001b[1;32m 3332\u001b[0m step_output \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m-> 3333\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m inp \u001b[38;5;129;01min\u001b[39;00m precomputed_input:\n\u001b[1;32m 3334\u001b[0m hx, cx \u001b[38;5;241m=\u001b[39m lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 3335\u001b[0m step_output\u001b[38;5;241m.\u001b[39mappend(hx)\n", "File 
\u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_tensor.py:1119\u001b[0m, in \u001b[0;36mTensor.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1110\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_tracing_state():\n\u001b[1;32m 1111\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m 1112\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIterating over a tensor might cause the trace to be incorrect. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1113\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing a tensor of different shape won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt change the number of \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1117\u001b[0m stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m,\n\u001b[1;32m 1118\u001b[0m )\n\u001b[0;32m-> 1119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28miter\u001b[39m(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munbind\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py:534\u001b[0m, in \u001b[0;36mFunctionalTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 525\u001b[0m outs_wrapped \u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_map_only(\n\u001b[1;32m 526\u001b[0m torch\u001b[38;5;241m.\u001b[39mTensor, wrap, outs_unwrapped\n\u001b[1;32m 527\u001b[0m )\n\u001b[1;32m 528\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 529\u001b[0m \u001b[38;5;66;03m# When we dispatch to the C++ functionalization kernel, we might need to jump back to the\u001b[39;00m\n\u001b[1;32m 530\u001b[0m \u001b[38;5;66;03m# PreDispatch mode stack afterwards, to handle any other PreDispatch 
modes underneath\u001b[39;00m\n\u001b[1;32m 531\u001b[0m \u001b[38;5;66;03m# FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch\u001b[39;00m\n\u001b[1;32m 532\u001b[0m \u001b[38;5;66;03m# from the TLS in order to avoid infinite looping, but this would prevent us from coming\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;66;03m# back to PreDispatch later\u001b[39;00m\n\u001b[0;32m--> 534\u001b[0m outs_unwrapped \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_op_dk\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 535\u001b[0m \u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_C\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDispatchKey\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mFunctionalize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 536\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs_unwrapped\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 537\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs_unwrapped\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 538\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[38;5;66;03m# We don't allow any mutation on result of dropout or _to_copy\u001b[39;00m\n\u001b[1;32m 540\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mexport:\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_stats.py:21\u001b[0m, in \u001b[0;36mcount..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 19\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 20\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m 
simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1238\u001b[0m, in \u001b[0;36mFakeTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1234\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m (\n\u001b[1;32m 1235\u001b[0m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_dispatch_mode(torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_TorchDispatchModeKey\u001b[38;5;241m.\u001b[39mFAKE) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1236\u001b[0m ), func\n\u001b[1;32m 1237\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1239\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 1240\u001b[0m log\u001b[38;5;241m.\u001b[39mexception(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake tensor raised TypeError\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File 
\u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1692\u001b[0m, in \u001b[0;36mFakeTensorMode.dispatch\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1689\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1691\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_enabled:\n\u001b[0;32m-> 1692\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cached_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dispatch_impl(func, types, args, kwargs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1348\u001b[0m, in \u001b[0;36mFakeTensorMode._cached_dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1345\u001b[0m FakeTensorMode\u001b[38;5;241m.\u001b[39mcache_bypasses[e\u001b[38;5;241m.\u001b[39mreason] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output \u001b[38;5;129;01mis\u001b[39;00m _UNASSIGNED:\n\u001b[0;32m-> 1348\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1350\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1943\u001b[0m, in \u001b[0;36mFakeTensorMode._dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1933\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m func \u001b[38;5;129;01min\u001b[39;00m decomposition_table \u001b[38;5;129;01mand\u001b[39;00m (\n\u001b[1;32m 1934\u001b[0m has_symbolic_sizes\n\u001b[1;32m 1935\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1940\u001b[0m )\n\u001b[1;32m 1941\u001b[0m ):\n\u001b[1;32m 1942\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m-> 1943\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mdecomposition_table\u001b[49m\u001b[43m[\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1945\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m 1946\u001b[0m \u001b[38;5;66;03m# Decomposes CompositeImplicitAutograd ops\u001b[39;00m\n\u001b[1;32m 1947\u001b[0m r \u001b[38;5;241m=\u001b[39m func\u001b[38;5;241m.\u001b[39mdecompose(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_refs/__init__.py:3956\u001b[0m, in \u001b[0;36munbind\u001b[0;34m(t, dim)\u001b[0m\n\u001b[1;32m 3953\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ()\n\u001b[1;32m 
3954\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3955\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m(\n\u001b[0;32m-> 3956\u001b[0m torch\u001b[38;5;241m.\u001b[39msqueeze(s, dim) \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtensor_split\u001b[49m\u001b[43m(\u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3957\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/utils/_stats.py:21\u001b[0m, in \u001b[0;36mcount..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 19\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 20\u001b[0m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m=\u001b[39m simple_call_counter[fn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m] \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1238\u001b[0m, in \u001b[0;36mFakeTensorMode.__torch_dispatch__\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1234\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m 
(\n\u001b[1;32m 1235\u001b[0m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_dispatch_mode(torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_TorchDispatchModeKey\u001b[38;5;241m.\u001b[39mFAKE) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1236\u001b[0m ), func\n\u001b[1;32m 1237\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1239\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 1240\u001b[0m log\u001b[38;5;241m.\u001b[39mexception(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfake tensor raised TypeError\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1692\u001b[0m, in \u001b[0;36mFakeTensorMode.dispatch\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1689\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1691\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcache_enabled:\n\u001b[0;32m-> 1692\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cached_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1693\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dispatch_impl(func, types, args, kwargs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1348\u001b[0m, in \u001b[0;36mFakeTensorMode._cached_dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1345\u001b[0m FakeTensorMode\u001b[38;5;241m.\u001b[39mcache_bypasses[e\u001b[38;5;241m.\u001b[39mreason] \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 1347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output \u001b[38;5;129;01mis\u001b[39;00m _UNASSIGNED:\n\u001b[0;32m-> 1348\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dispatch_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtypes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1350\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py:1947\u001b[0m, in \u001b[0;36mFakeTensorMode._dispatch_impl\u001b[0;34m(self, func, types, args, kwargs)\u001b[0m\n\u001b[1;32m 1943\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m decomposition_table[func](\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1945\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m 1946\u001b[0m \u001b[38;5;66;03m# Decomposes 
CompositeImplicitAutograd ops\u001b[39;00m\n\u001b[0;32m-> 1947\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecompose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1948\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m r \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m:\n\u001b[1;32m 1949\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m r\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py:759\u001b[0m, in \u001b[0;36mOpOverload.decompose\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 757\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpy_kernels[dk](\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 758\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_dispatch_has_kernel_for_dispatch_key(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname(), dk):\n\u001b[0;32m--> 759\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_op_dk\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 760\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 761\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mNotImplemented\u001b[39m\n", "File 
\u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/sym_node.py:429\u001b[0m, in \u001b[0;36mSymNode.guard_int\u001b[0;34m(self, file, line)\u001b[0m\n\u001b[1;32m 426\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mguard_int\u001b[39m(\u001b[38;5;28mself\u001b[39m, file, line):\n\u001b[1;32m 427\u001b[0m \u001b[38;5;66;03m# TODO: use the file/line for some useful diagnostic on why a\u001b[39;00m\n\u001b[1;32m 428\u001b[0m \u001b[38;5;66;03m# guard occurred\u001b[39;00m\n\u001b[0;32m--> 429\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate_expr\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfx_node\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfx_node\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 430\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 431\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mint\u001b[39m(r)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/recording.py:262\u001b[0m, in \u001b[0;36mrecord_shapeenv_event..decorator..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m args[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mis_recording: \u001b[38;5;66;03m# type: ignore[has-type]\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \u001b[38;5;66;03m# If ShapeEnv is already recording an event, call the wrapped\u001b[39;00m\n\u001b[1;32m 258\u001b[0m 
\u001b[38;5;66;03m# function directly.\u001b[39;00m\n\u001b[1;32m 259\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# NB: here, we skip the check of whether all ShapeEnv instances\u001b[39;00m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# are equal, in favor of a faster dispatch.\u001b[39;00m\n\u001b[0;32m--> 262\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m retlog(\u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 264\u001b[0m \u001b[38;5;66;03m# Retrieve an instance of ShapeEnv.\u001b[39;00m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;66;03m# Assumption: the collection of args and kwargs may not reference\u001b[39;00m\n\u001b[1;32m 266\u001b[0m \u001b[38;5;66;03m# different ShapeEnv instances.\u001b[39;00m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;28mself\u001b[39m \u001b[38;5;241m=\u001b[39m _extract_shape_env_and_assert_equal(args, kwargs)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5122\u001b[0m, in \u001b[0;36mShapeEnv.evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, forcing_spec)\u001b[0m\n\u001b[1;32m 5117\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(\u001b[38;5;241m256\u001b[39m)\n\u001b[1;32m 5118\u001b[0m \u001b[38;5;129m@record_shapeenv_event\u001b[39m(save_tracked_fakes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 5119\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mevaluate_expr\u001b[39m(\u001b[38;5;28mself\u001b[39m, orig_expr: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msympy.Expr\u001b[39m\u001b[38;5;124m\"\u001b[39m, hint\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, 
fx_node\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 5120\u001b[0m size_oblivious: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m, forcing_spec: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 5121\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 5122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluate_expr\u001b[49m\u001b[43m(\u001b[49m\u001b[43morig_expr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhint\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfx_node\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize_oblivious\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforcing_spec\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforcing_spec\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5123\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m 5124\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog\u001b[38;5;241m.\u001b[39mwarning(\n\u001b[1;32m 5125\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfailed during evaluate_expr(\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, hint=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, size_oblivious=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m, forcing_spec=\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 5126\u001b[0m orig_expr, hint, size_oblivious, forcing_spec\n\u001b[1;32m 5127\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:5238\u001b[0m, in \u001b[0;36mShapeEnv._evaluate_expr\u001b[0;34m(self, orig_expr, hint, fx_node, size_oblivious, 
forcing_spec)\u001b[0m\n\u001b[1;32m 5236\u001b[0m concrete_val \u001b[38;5;241m=\u001b[39m unsound_result\n\u001b[1;32m 5237\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 5238\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_data_dependent_error(\n\u001b[1;32m 5239\u001b[0m expr\u001b[38;5;241m.\u001b[39mxreplace(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvar_to_val),\n\u001b[1;32m 5240\u001b[0m expr,\n\u001b[1;32m 5241\u001b[0m size_oblivious_result\u001b[38;5;241m=\u001b[39msize_oblivious_result\n\u001b[1;32m 5242\u001b[0m )\n\u001b[1;32m 5243\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 5244\u001b[0m expr \u001b[38;5;241m=\u001b[39m new_expr\n", "\u001b[0;31mGuardOnDataDependentSymNode\u001b[0m: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: u0)\n\nPotential framework code culprit (scroll up for full backtrace):\n File \"/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_ops.py\", line 759, in decompose\n return self._op_dk(dk, *args, **kwargs)\n\nFor more information, run with TORCH_LOGS=\"dynamic\"\nFor extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"u0\"\nIf you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\nFor more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n\nThe following call raised this error:\n File \"/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-82M/models.py\", line 471, in F0Ntrain\n x2, _temp = self.shared(x1)\n\nTo fix the error, insert one of the following checks before this call:\n 1. torch._check(x.shape[2])\n 2. 
torch._check(~x.shape[2])\n\n(These suggested fixes were derived by replacing `u0` with x.shape[2] or x1.shape[1] in u0 and its negation.)" ] } ], "source": [ "os.environ['TORCH_LOGS'] = '+dynamic'\n", "os.environ['TORCH_LOGS'] = '+export'\n", "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED']=\"u0 >= 0\"\n", "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CPP']=\"1\"\n", "os.environ['TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL']=\"u0\"\n", "\n", "class StyleTTS2(torch.nn.Module):\n", " def __init__(self, model, voicepack):\n", " super().__init__()\n", " self.model = model\n", " self.voicepack = voicepack\n", " \n", " def forward(self, tokens):\n", " speed = 1.\n", " # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n", " device = tokens.device\n", " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n", "\n", " text_mask = length_to_mask(input_lengths).to(device)\n", " bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())\n", "\n", " d_en = self.model[\"bert_encoder\"](bert_dur).transpose(-1, -2)\n", "\n", " ref_s = self.voicepack[tokens.shape[1]]\n", " s = ref_s[:, 128:]\n", "\n", " d = self.model[\"predictor\"].text_encoder.inference(d_en, s)\n", " x, _ = self.model[\"predictor\"].lstm(d)\n", "\n", " duration = self.model[\"predictor\"].duration_proj(x)\n", " duration = torch.sigmoid(duration).sum(axis=-1) / speed\n", " pred_dur = torch.round(duration).clamp(min=1).long()\n", " \n", " c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n", " c_end = c_start + pred_dur[0,:]\n", " indices = torch.arange(0, pred_dur.sum().item()).long().to(device)\n", "\n", " pred_aln_trg_list=[]\n", " for cs, ce in zip(c_start, c_end):\n", " row = torch.where((indices>=cs) & (indices 670\u001b[0m \u001b[43mproduce_guards_callback\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) 
\u001b[38;5;28;01mas\u001b[39;00m e:\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1655\u001b[0m, in \u001b[0;36m_non_strict_export.._produce_guards_callback\u001b[0;34m(gm)\u001b[0m\n\u001b[1;32m 1654\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_produce_guards_callback\u001b[39m(gm):\n\u001b[0;32m-> 1655\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mproduce_guards_and_solve_constraints\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1656\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfake_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1657\u001b[0m \u001b[43m \u001b[49m\u001b[43mgm\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1658\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtransformed_dynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1659\u001b[0m \u001b[43m \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1660\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moriginal_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1661\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1662\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:305\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m 
constraint_violation_error:\n\u001b[0;32m--> 305\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m constraint_violation_error\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/_export/non_strict_utils.py:270\u001b[0m, in \u001b[0;36mproduce_guards_and_solve_constraints\u001b[0;34m(fake_mode, gm, dynamic_shapes, equalities_inputs, original_signature, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 270\u001b[0m \u001b[43mshape_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mproduce_guards\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mplaceholders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43msources\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_contexts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_contexts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mequalities_inputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequalities_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_static\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ConstraintViolationError \u001b[38;5;28;01mas\u001b[39;00m e:\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py:4178\u001b[0m, in \u001b[0;36mShapeEnv.produce_guards\u001b[0;34m(self, placeholders, sources, source_ref, guards, input_contexts, equalities_inputs, _simplified, ignore_static)\u001b[0m\n\u001b[1;32m 4177\u001b[0m err \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(error_msgs)\n\u001b[0;32m-> 4178\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ConstraintViolationError(\n\u001b[1;32m 4179\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mConstraints violated (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdebug_names\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m)! \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4180\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mFor more information, run with TORCH_LOGS=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m+dynamic\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 4181\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 4182\u001b[0m )\n\u001b[1;32m 4183\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(warn_msgs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n", "\u001b[0;31mConstraintViolationError\u001b[0m: Constraints violated (token_len)! 
For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143", "\nDuring handling of the above exception, another exception occurred:\n", "\u001b[0;31mUserError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[33], line 61\u001b[0m\n\u001b[1;32m 58\u001b[0m dynamic_shapes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtokens0\u001b[39m\u001b[38;5;124m\"\u001b[39m:{\u001b[38;5;241m0\u001b[39m:token_len}}\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m export_mod \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexport\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_model\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokens\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;66;03m# export_mod = torch.export.export(test_model, args=( tokens[0,:], ), strict=False).run_decompositions()\u001b[39;00m\n\u001b[1;32m 63\u001b[0m \u001b[38;5;28mprint\u001b[39m(export_mod)\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/__init__.py:270\u001b[0m, in 
\u001b[0;36mexport\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(mod, torch\u001b[38;5;241m.\u001b[39mjit\u001b[38;5;241m.\u001b[39mScriptModule):\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 266\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExporting a ScriptModule is not supported. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMaybe try converting your ScriptModule to an ExportedProgram \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124musing `TS2EPConverter(mod, args, kwargs).convert()` instead.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 269\u001b[0m )\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_export\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstrict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 276\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 278\u001b[0m 
\u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1017\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1010\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1011\u001b[0m log_export_usage(\n\u001b[1;32m 1012\u001b[0m event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.error.unclassified\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 1013\u001b[0m \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39merror_type,\n\u001b[1;32m 1014\u001b[0m message\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m(e),\n\u001b[1;32m 1015\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 1016\u001b[0m )\n\u001b[0;32m-> 1017\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 1019\u001b[0m _EXPORT_FLAGS \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:990\u001b[0m, in \u001b[0;36m_log_export_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 989\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 990\u001b[0m ep \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 991\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 992\u001b[0m log_export_usage(\n\u001b[1;32m 993\u001b[0m 
event\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mexport.time\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 994\u001b[0m metrics\u001b[38;5;241m=\u001b[39mend \u001b[38;5;241m-\u001b[39m start,\n\u001b[1;32m 995\u001b[0m flags\u001b[38;5;241m=\u001b[39m_EXPORT_FLAGS,\n\u001b[1;32m 996\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mget_ep_stats(ep),\n\u001b[1;32m 997\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/exported_program.py:114\u001b[0m, in \u001b[0;36m_disable_prexisiting_fake_mode..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(fn)\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mwrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m unset_fake_temporarily():\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1880\u001b[0m, in \u001b[0;36m_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 1877\u001b[0m \u001b[38;5;66;03m# Call the appropriate export function based on the strictness of tracing.\u001b[39;00m\n\u001b[1;32m 1878\u001b[0m export_func \u001b[38;5;241m=\u001b[39m _strict_export \u001b[38;5;28;01mif\u001b[39;00m strict \u001b[38;5;28;01melse\u001b[39;00m _non_strict_export\n\u001b[0;32m-> 
1880\u001b[0m export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43mexport_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1881\u001b[0m \u001b[43m \u001b[49m\u001b[43mmod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1882\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1883\u001b[0m \u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1884\u001b[0m \u001b[43m \u001b[49m\u001b[43mdynamic_shapes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1885\u001b[0m \u001b[43m \u001b[49m\u001b[43mpreserve_module_call_signature\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1886\u001b[0m \u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1887\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1888\u001b[0m \u001b[43m \u001b[49m\u001b[43moriginal_in_spec\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1889\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_complex_guards_as_runtime_asserts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1890\u001b[0m \u001b[43m \u001b[49m\u001b[43m_is_torch_jit_trace\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1891\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1892\u001b[0m export_graph_signature: ExportGraphSignature \u001b[38;5;241m=\u001b[39m export_artifact\u001b[38;5;241m.\u001b[39maten\u001b[38;5;241m.\u001b[39msig\n\u001b[1;32m 1894\u001b[0m forward_arg_names \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1895\u001b[0m _get_forward_arg_names(mod, args, kwargs) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_torch_jit_trace \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1896\u001b[0m )\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:1683\u001b[0m, in 
\u001b[0;36m_non_strict_export\u001b[0;34m(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, dispatch_tracing_mode)\u001b[0m\n\u001b[1;32m 1667\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) \u001b[38;5;28;01mas\u001b[39;00m (\n\u001b[1;32m 1668\u001b[0m patched_mod,\n\u001b[1;32m 1669\u001b[0m new_fake_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1672\u001b[0m map_fake_to_real,\n\u001b[1;32m 1673\u001b[0m ):\n\u001b[1;32m 1674\u001b[0m _to_aten_func \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1675\u001b[0m _export_to_aten_ir_make_fx\n\u001b[1;32m 1676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dispatch_tracing_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmake_fx\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1681\u001b[0m )\n\u001b[1;32m 1682\u001b[0m )\n\u001b[0;32m-> 1683\u001b[0m aten_export_artifact \u001b[38;5;241m=\u001b[39m \u001b[43m_to_aten_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# type: ignore[operator]\u001b[39;49;00m\n\u001b[1;32m 1684\u001b[0m \u001b[43m \u001b[49m\u001b[43mpatched_mod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1685\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1686\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1687\u001b[0m \u001b[43m \u001b[49m\u001b[43mfake_params_buffers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1688\u001b[0m \u001b[43m \u001b[49m\u001b[43mnew_fake_constant_attrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1689\u001b[0m \u001b[43m \u001b[49m\u001b[43mproduce_guards_callback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_produce_guards_callback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1690\u001b[0m 
\u001b[43m \u001b[49m\u001b[43mtransform\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_tuplify_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1691\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;66;03m# aten_export_artifact.constants contains only fake script objects, we need to map them back\u001b[39;00m\n\u001b[1;32m 1693\u001b[0m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 1694\u001b[0m fqn: map_fake_to_real[obj] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obj, FakeScriptObject) \u001b[38;5;28;01melse\u001b[39;00m obj\n\u001b[1;32m 1695\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m fqn, obj \u001b[38;5;129;01min\u001b[39;00m aten_export_artifact\u001b[38;5;241m.\u001b[39mconstants\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 1696\u001b[0m }\n", "File \u001b[0;32m~/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/export/_trace.py:672\u001b[0m, in \u001b[0;36m_export_to_aten_ir\u001b[0;34m(mod, fake_args, fake_kwargs, fake_params_buffers, constant_attrs, produce_guards_callback, transform, pre_dispatch, decomp_table, _check_autograd_state, _is_torch_jit_trace)\u001b[0m\n\u001b[1;32m 670\u001b[0m produce_guards_callback(gm)\n\u001b[1;32m 671\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (ConstraintViolationError, ValueRangeError) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m--> 672\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m UserError(UserErrorType\u001b[38;5;241m.\u001b[39mCONSTRAINT_VIOLATION, \u001b[38;5;28mstr\u001b[39m(e)) \u001b[38;5;66;03m# noqa: B904\u001b[39;00m\n\u001b[1;32m 674\u001b[0m \u001b[38;5;66;03m# Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.\u001b[39;00m\n\u001b[1;32m 675\u001b[0m \u001b[38;5;66;03m# Overwrite output specs afterwards.\u001b[39;00m\n\u001b[1;32m 676\u001b[0m flat_fake_args 
\u001b[38;5;241m=\u001b[39m pytree\u001b[38;5;241m.\u001b[39mtree_leaves((fake_args, fake_kwargs))\n", "\u001b[0;31mUserError\u001b[0m: Constraints violated (token_len)! For more information, run with TORCH_LOGS=\"+dynamic\".\n - Not all values of token_len = L['args'][0][0].size()[0] in the specified range are valid because token_len was inferred to be a constant (143).\nSuggested fixes:\n token_len = 143" ] } ], "source": [ "os.environ['TORCH_LOGS'] = '+dynamic'\n", "os.environ['TORCH_LOGS'] = '+export'\n", "class test(torch.nn.Module):\n", " def __init__(self, model, voicepack):\n", " super().__init__()\n", " self.model = model\n", " self.voicepack = voicepack\n", " self.model.text_encoder.lstm.flatten_parameters()\n", " \n", " def forward(self, tokens0):\n", " tokens = tokens0.unsqueeze(0)\n", " print(tokens.shape)\n", " # speed = 1.\n", " # # tokens = torch.nn.functional.pad(tokens, (0, 510 - tokens.shape[-1]))\n", " # device = tokens.device\n", " input_lengths = torch.LongTensor([tokens0.shape[-1]]).to(device)\n", "\n", " # text_mask = length_to_mask(input_lengths).to(device)\n", " # bert_dur = self.model['bert'](tokens, attention_mask=(~text_mask).int())\n", "\n", " # d_en = self.model[\"bert_encoder\"](bert_dur).transpose(-1, -2)\n", "\n", " # ref_s = self.voicepack[tokens.shape[1]]\n", " # s = ref_s[:, 128:]\n", "\n", " # d = self.model[\"predictor\"].text_encoder.inference(d_en, s)\n", " # x, _ = self.model[\"predictor\"].lstm(d)\n", "\n", " # duration = self.model[\"predictor\"].duration_proj(x)\n", " # duration = torch.sigmoid(duration).sum(axis=-1) / speed\n", " # pred_dur = torch.round(duration).clamp(min=1).long()\n", " \n", " # c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n", " # c_end = c_start + pred_dur[0,:]\n", " # indices = torch.arange(0, pred_dur.sum().item()).long().to(device)\n", "\n", " # pred_aln_trg_list=[]\n", " # for cs, ce in zip(c_start, c_end):\n", " # row = torch.where((indices>=cs) & (indices" ] }, 
"metadata": {}, "output_type": "display_data" } ], "source": [ "import torch.nn.functional as F\n", "\n", "# pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())\n", "c_start = F.pad(pred_dur,(1,0), \"constant\").cumsum(dim=1)[0,0:-1]\n", "c_end = c_start + pred_dur[0,:]\n", "indices = torch.arange(0, pred_dur.sum().item()).to(device)\n", "\n", "pred_aln_trg_list=[]\n", "for cs, ce in zip(c_start, c_end):\n", " row = torch.where((indices>=cs) & (indices 41\u001b[0m pred_aln_trg \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_alignment_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_lengths\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpred_dur\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 42\u001b[0m pl\u001b[38;5;241m.\u001b[39mimshow(pred_aln_trg)\n", "Cell \u001b[0;32mIn[48], line 22\u001b[0m, in \u001b[0;36mcreate_alignment_matrix\u001b[0;34m(input_lengths, pred_dur)\u001b[0m\n\u001b[1;32m 19\u001b[0m col_indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39marange(pred_dur\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m.\u001b[39mitem())\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\u001b[38;5;241m.\u001b[39mrepeat(input_lengths, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;66;03m# Create a mask based on durations\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m mask \u001b[38;5;241m=\u001b[39m \u001b[43mcol_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m<\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred_dur\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m# Create offset indices for the columns\u001b[39;00m\n\u001b[1;32m 25\u001b[0m offset \u001b[38;5;241m=\u001b[39m 
torch\u001b[38;5;241m.\u001b[39mcat((torch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m0\u001b[39m]), cum_dur[\u001b[38;5;241m0\u001b[39m, :\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]))\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mrepeat(\u001b[38;5;241m1\u001b[39m, pred_dur\u001b[38;5;241m.\u001b[39mmax()\u001b[38;5;241m.\u001b[39mitem())\n", "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (23) must match the size of tensor b (143) at non-singleton dimension 2" ] } ], "source": [ "\n", "\n", "def create_alignment_matrix(input_lengths, pred_dur):\n", " \"\"\"Creates an alignment matrix without explicit loops.\n", "\n", " Args:\n", " input_lengths: Number of input units (int).\n", " pred_dur: Predicted durations (torch.Tensor of shape (1, input_lengths)).\n", "\n", " Returns:\n", " pred_aln_trg: Alignment matrix (torch.Tensor of shape (input_lengths, pred_dur.sum())).\n", " \"\"\"\n", " total_duration = pred_dur.sum().item()\n", " pred_aln_trg = torch.zeros(input_lengths, total_duration)\n", "\n", " # Calculate cumulative durations\n", " cum_dur = torch.cumsum(pred_dur, dim=1)\n", "\n", " # Create indices for filling the matrix\n", " row_indices = torch.arange(input_lengths).unsqueeze(1).repeat(1, pred_dur.max().item())\n", " col_indices = torch.arange(pred_dur.max().item()).unsqueeze(0).repeat(input_lengths, 1)\n", "\n", " # Create a mask based on durations\n", " mask = col_indices < pred_dur.unsqueeze(1)\n", "\n", " # Create offset indices for the columns\n", " offset = torch.cat((torch.tensor([0]), cum_dur[0, :-1])).unsqueeze(1).repeat(1, pred_dur.max().item())\n", "\n", " # Apply the mask and offset to generate the final column indices\n", " final_col_indices = (col_indices + offset) * mask\n", "\n", " # Flatten indices and create a flattened index tensor\n", " flat_row_indices = row_indices[mask].long()\n", " flat_col_indices = final_col_indices[mask].long()\n", " flat_indices = 
torch.stack([flat_row_indices, flat_col_indices], dim=1)\n", "\n", " # Scatter ones into the alignment matrix\n", " pred_aln_trg[flat_indices.T[0], flat_indices.T[1]] = 1\n", "\n", " return pred_aln_trg\n", "\n", "\n", "pred_aln_trg = create_alignment_matrix(input_lengths.item(), pred_dur)\n", "pl.imshow(pred_aln_trg)\n" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([143])" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "input_lengths" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())\n", "c_frame = 0\n", "\n", "for i in range(pred_aln_trg.size(0)):\n", " pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1\n", " c_frame += pred_dur[0,i].item()" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "style_model.eval" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CustomAlbert(\n", " (embeddings): AlbertEmbeddings(\n", " (word_embeddings): Embedding(178, 128, padding_idx=0)\n", " (position_embeddings): Embedding(512, 128)\n", " (token_type_embeddings): Embedding(2, 128)\n", " (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n", " (dropout): Dropout(p=0, inplace=False)\n", " )\n", " (encoder): AlbertTransformer(\n", " (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)\n", " (albert_layer_groups): ModuleList(\n", " (0): AlbertLayerGroup(\n", " (albert_layers): ModuleList(\n", " (0): AlbertLayer(\n", " (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " (attention): AlbertAttention(\n", " (query): Linear(in_features=768, out_features=768, 
bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (attention_dropout): Dropout(p=0, inplace=False)\n", " (output_dropout): Dropout(p=0, inplace=False)\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", " )\n", " (ffn): Linear(in_features=768, out_features=2048, bias=True)\n", " (ffn_output): Linear(in_features=2048, out_features=768, bias=True)\n", " (activation): NewGELUActivation()\n", " (dropout): Dropout(p=0, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (pooler): Linear(in_features=768, out_features=768, bias=True)\n", " (pooler_activation): Tanh()\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model['bert']" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "only integer tensors of a single element can be converted to an index", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[17], line 11\u001b[0m\n\u001b[1;32m 8\u001b[0m pred_aln_trg1 \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mzeros(input_lengths, pred_dur\u001b[38;5;241m.\u001b[39msum()\u001b[38;5;241m.\u001b[39mitem(), dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 9\u001b[0m batch_indices \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39marange(input_lengths\u001b[38;5;241m.\u001b[39mitem())\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m---> 11\u001b[0m \u001b[43mpred_aln_trg1\u001b[49m\u001b[43m[\u001b[49m\u001b[43mbatch_indices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mstart_indices\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mend_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 13\u001b[0m pl\u001b[38;5;241m.\u001b[39mimshow(pred_aln_trg1)\n", "\u001b[0;31mTypeError\u001b[0m: only integer tensors of a single element can be converted to an index" ] } ], "source": [ "# Process durations\n", "\n", "cumsum_dur = torch.cumsum(pred_dur, dim=1).to(device)\n", "end_indices = cumsum_dur - 1\n", "start_indices = torch.cat([torch.zeros(1, 1, dtype=torch.long).to(device), end_indices[:, :-1] + 1], dim=1)\n", "\n", "# Create binary alignment target\n", "pred_aln_trg1 = torch.zeros(input_lengths, pred_dur.sum().item(), dtype=torch.float32)\n", "batch_indices = torch.arange(input_lengths.item()).unsqueeze(1)\n", "\n", "pred_aln_trg1[batch_indices, start_indices: end_indices + 1] = 1\n", "\n", "pl.imshow(pred_aln_trg1)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batch_indices" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([143, 329])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred_aln_trg1 = torch.zeros(input_lengths, pred_dur.sum().item()).to(device)\n", "a = torch.arange(pred_aln_trg1.size(0))[:, None].repeat(1, pred_dur.size(1)).to(device)\n", "b = (torch.arange(pred_dur.size(1)).repeat(pred_aln_trg1.size(0), 1).to(device) < pred_dur).to(torch.float32).to(device)\n", "print(pred_aln_trg.dtype, pred_aln_trg1.dtype, a.dtype, b.dtype)\n", "print(a.device, b.device, pred_dur.device)\n", "pred_aln_trg1.scatter_(1, \n", " a, \n", " b)\n", "\n", "pl.imshow(pred_aln_trg1.detach().cpu().numpy())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [ { "ename": "RuntimeError", "evalue": "Expected index [1, 25] to be smaller than self [143, 329] apart from dimension 1 and to be smaller size than src [1, 1]", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[8], line 18\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;66;03m# Use scatter_add_ to set the appropriate slices to 1\u001b[39;00m\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(pred_dur\u001b[38;5;241m.\u001b[39msize(\u001b[38;5;241m1\u001b[39m)):\n\u001b[0;32m---> 18\u001b[0m \t\u001b[43mmask\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscatter_add_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart_indices\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m:\u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred_dur\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclamp\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mmax\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpred_aln_trg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalues\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m:\u001b[49m\u001b[43mi\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# Apply the mask to pred_aln_trg\u001b[39;00m\n\u001b[1;32m 21\u001b[0m pred_aln_trg \u001b[38;5;241m=\u001b[39m mask\n", "\u001b[0;31mRuntimeError\u001b[0m: Expected index [1, 25] to be smaller than self [143, 329] apart from dimension 1 and to be smaller size than src [1, 1]" ] } ], "source": [ "# Calculate the cumulative sum of durations to get the end indices\n", "cumulative_durations = torch.cumsum(pred_dur, dim=1).to(device)\n", "\n", "# Calculate the start indices by shifting the cumulative durations\n", "start_indices = cumulative_durations - pred_dur\n", "\n", "# Create a tensor of indices for pred_aln_trg\n", "indices = torch.arange(pred_aln_trg.size(1)).to(device)\n", 
"\n", "# Create a mask tensor initialized to zeros\n", "mask = torch.zeros_like(pred_aln_trg).to(device)\n", "\n", "# Create a tensor to hold the values to scatter\n", "values = torch.ones_like(pred_dur, dtype=pred_aln_trg.dtype).to(device)\n", "\n", "# Use scatter_ to set the appropriate slices to 1\n", "mask.scatter_(1, start_indices.unsqueeze(2) + torch.arange(pred_dur.max()).unsqueeze(0).unsqueeze(0).to(device), values.unsqueeze(2))\n", "\n", "# Apply the mask to pred_aln_trg\n", "pred_aln_trg = mask" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cpu')" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.arange(pred_dur.size(1)).repeat(pred_aln_trg1.size(0), 1).device" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pl.imshow(pred_aln_trg.detach().cpu().numpy())" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 143])\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(pred_dur.shape)\n", "pl.plot(pred_dur[0,:].detach().cpu().numpy().cumsum());" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[50, 157, 43, 135, 16, 53, 135, 46, 16, 43, 102, 16, 56, 156, 57, 135, 6, 16, 102, 62, 61, 16, 70, 56, 16, 138, 56, 156, 72, 56, 61, 85, 123, 83, 44, 83, 54, 16, 53, 65, 156, 86, 61, 62, 131, 83, 56, 4, 16, 54, 156, 43, 102, 53, 16, 156, 72, 61, 53, 102, 112, 16, 70, 56, 16, 138, 56, 44, 156, 76, 158, 123, 56, 16, 62, 131, 156, 43, 102, 54, 46, 16, 102, 48, 16, 81, 47, 102, 54, 16, 54, 156, 51, 158, 46, 16, 70, 16, 92, 156, 135, 46, 16, 54, 156, 43, 102, 48, 4, 16, 81, 47, 102, 16, 50, 156, 72, 64, 83, 56, 62, 16, 156, 51, 158, 64, 83, 56, 16, 44, 157, 102, 56, 16, 44, 156, 76, 158, 123, 56, 4]\n" ] } ], "source": [ "ps = phonemize(text)\n", "tokens = tokenize(ps)\n", "print(tokens)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from models import build_model\n", "import torch\n", "device = \"cpu\" #'cuda' if torch.cuda.is_available() else 'cpu'\n", "model = build_model('kokoro-v0_19.pth', device)\n", "voicepack = torch.load('voices/af.pt', weights_only=True).to(device)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "bert = model[\"bert\"]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "embeddings.word_embeddings.weight torch.Size([178, 128])\n", "embeddings.position_embeddings.weight torch.Size([512, 128])\n", "embeddings.token_type_embeddings.weight torch.Size([2, 128])\n", "embeddings.LayerNorm.weight torch.Size([128])\n", "embeddings.LayerNorm.bias torch.Size([128])\n", "encoder.embedding_hidden_mapping_in.weight torch.Size([768, 128])\n", "encoder.embedding_hidden_mapping_in.bias 
torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.weight torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm.bias torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.query.weight torch.Size([768, 768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.query.bias torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.key.weight torch.Size([768, 768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.key.bias torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.value.weight torch.Size([768, 768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.value.bias torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.weight torch.Size([768, 768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.dense.bias torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.weight torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm.bias torch.Size([768])\n", "encoder.albert_layer_groups.0.albert_layers.0.ffn.weight torch.Size([2048, 768])\n", "encoder.albert_layer_groups.0.albert_layers.0.ffn.bias torch.Size([2048])\n", "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.weight torch.Size([768, 2048])\n", "encoder.albert_layer_groups.0.albert_layers.0.ffn_output.bias torch.Size([768])\n", "pooler.weight torch.Size([768, 768])\n", "pooler.bias torch.Size([768])\n" ] } ], "source": [ "# show all parameters of model bert\n", "for name, param in bert.named_parameters():\n", " print(name, param.requires_grad())\n", " # print(param)\n", " # print(param.shape)\n", " # break" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Testing LSTM export" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x1.shape=torch.Size([1, 300, 256])\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/onnx/symbolic_opset9.py:4279: UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model. \n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Exported graph: graph(%x : Float(*, 300, 128, strides=[38400, 128, 1], requires_grad=0, device=cpu),\n", " %onnx::LSTM_194 : Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cpu),\n", " %onnx::LSTM_195 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu),\n", " %onnx::LSTM_196 : Float(2, 512, 128, strides=[65536, 128, 1], requires_grad=0, device=cpu)):\n", " %/lstm/Shape_output_0 : Long(3, strides=[1], device=cpu) = onnx::Shape[onnx_name=\"/lstm/Shape\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n", " %/lstm/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name=\"/lstm/Constant\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n", " %/lstm/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name=\"/lstm/Gather\"](%/lstm/Shape_output_0, %/lstm/Constant_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1081:0\n", " %/lstm/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) 
= onnx::Constant[value={2}, onnx_name=\"/lstm/Constant_1\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n", " %onnx::Unsqueeze_16 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()\n", " %/lstm/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name=\"/lstm/Unsqueeze\"](%/lstm/Gather_output_0, %onnx::Unsqueeze_16), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n", " %/lstm/Constant_2_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={128}, onnx_name=\"/lstm/Constant_2\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm\n", " %/lstm/Concat_output_0 : Long(3, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name=\"/lstm/Concat\"](%/lstm/Constant_1_output_0, %/lstm/Unsqueeze_output_0, %/lstm/Constant_2_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n", " %/lstm/ConstantOfShape_output_0 : Float(*, *, *, strides=[128, 128, 1], requires_grad=0, device=cpu) = onnx::ConstantOfShape[value={0}, onnx_name=\"/lstm/ConstantOfShape\"](%/lstm/Concat_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1085:0\n", " %/lstm/Transpose_output_0 : Float(300, *, 128, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose\"](%x), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " %onnx::LSTM_23 : Tensor? 
= prim::Constant(), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " %/lstm/LSTM_output_0 : Float(300, 2, *, 128, device=cpu), %/lstm/LSTM_output_1 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu), %/lstm/LSTM_output_2 : Float(2, *, 128, strides=[128, 128, 1], requires_grad=1, device=cpu) = onnx::LSTM[direction=\"bidirectional\", hidden_size=128, onnx_name=\"/lstm/LSTM\"](%/lstm/Transpose_output_0, %onnx::LSTM_195, %onnx::LSTM_196, %onnx::LSTM_194, %onnx::LSTM_23, %/lstm/ConstantOfShape_output_0, %/lstm/ConstantOfShape_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " %/lstm/Transpose_1_output_0 : Float(300, *, 2, 128, device=cpu) = onnx::Transpose[perm=[0, 2, 1, 3], onnx_name=\"/lstm/Transpose_1\"](%/lstm/LSTM_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " %/lstm/Constant_3_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value= 0 0 -1 [ CPULongType{3} ], onnx_name=\"/lstm/Constant_3\"](), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " %/lstm/Reshape_output_0 : Float(300, *, 256, device=cpu) = onnx::Reshape[allowzero=0, onnx_name=\"/lstm/Reshape\"](%/lstm/Transpose_1_output_0, %/lstm/Constant_3_output_0), scope: __main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " %151 : Float(*, 300, 256, strides=[256, 256, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[1, 0, 2], onnx_name=\"/lstm/Transpose_2\"](%/lstm/Reshape_output_0), scope: 
__main__.Model::/torch.nn.modules.rnn.LSTM::lstm # /rhome/eingerman/mambaforge/envs/styletts2/lib/python3.10/site-packages/torch/nn/modules/rnn.py:1123:0\n", " return (%151)\n", "\n" ] } ], "source": [
 "import torch\n",
 "\n",
 "# Minimal repro: one bidirectional LSTM layer exported with the TorchScript-based\n",
 "# ONNX exporter (dynamo=False), used to check that a dynamic sequence axis survives export.\n",
 "class Model(torch.nn.Module):\n",
 "    \"\"\"Single-layer bidirectional LSTM: 128 input features -> 2 * 128 output features.\"\"\"\n",
 "\n",
 "    def __init__(self):\n",
 "        super().__init__()\n",
 "        self.lstm = torch.nn.LSTM(128, 128, 1, bidirectional=True, batch_first=True)\n",
 "        # Orthogonal weights and zero biases for all LSTM parameters.\n",
 "        for name, param in self.lstm.named_parameters():\n",
 "            if 'weight' in name:\n",
 "                torch.nn.init.orthogonal_(param)\n",
 "            elif 'bias' in name:\n",
 "                torch.nn.init.zeros_(param)\n",
 "\n",
 "    def forward(self, x):\n",
 "        \"\"\"Run the LSTM on x of shape (batch, seq, 128) (batch_first=True).\n",
 "\n",
 "        Returns the per-step outputs, shape (batch, seq, 256); the final (h, c) states are discarded.\n",
 "        \"\"\"\n",
 "        out, _ = self.lstm(x)\n",
 "        return out\n",
 "\n",
 "\n",
 "model = Model()\n",
 "model = model.to(\"cpu\")\n",
 "model.eval()\n",
 "\n",
 "# Example input: batch=1, seq=300, features=128.\n",
 "xa = torch.zeros((1, 300, 128)).to(\"cpu\")\n",
 "x1 = model(xa)\n",
 "print(f\"{x1.shape=}\")\n",
 "\n",
 "# With batch_first=True, axis 0 is the batch and axis 1 is the token/sequence axis,\n",
 "# so \"ntokens\" belongs on axis 1. The batch axis is kept dynamic as well.\n",
 "dynamic_shapes = {\n",
 "    \"x\": {0: \"batch\", 1: \"ntokens\"},\n",
 "    \"y\": {0: \"batch\", 1: \"ntokens\"},\n",
 "}\n",
 "\n",
 "# With dynamo=False this writes model.onnx and returns None, so there is no returned\n",
 "# object to inspect; load model.onnx with onnx/onnxruntime to examine the graph.\n",
 "torch.onnx.export(\n",
 "    model=model,\n",
 "    args=(xa,),\n",
 "    f=\"model.onnx\",\n",
 "    input_names=[\"x\"],\n",
 "    output_names=[\"y\"],\n",
 "    dynamic_axes=dynamic_shapes,\n",
 "    verbose=True,\n",
 "    dynamo=False,\n",
 ")"
 ] },
 { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 143])\n" ] } ], "source": [
 "from kokoro import phonemize, tokenize\n",
 "from models_scripting import load_plbert\n",
 "\n",
 "bert = load_plbert()\n",
 "\n",
 "text = \"How could I know? It's an unanswerable question. Like asking an unborn child if they'll lead a good life. They haven't even been born.\"\n",
 "ps = phonemize(text, \"a\")\n",
 "token_ids = tokenize(ps)\n",
 "# Pad with token 0 at both ends and add a batch axis -> shape (1, len(token_ids) + 2).\n",
 "# NOTE(review): assumes `device` was defined by an earlier cell - confirm on a fresh kernel.\n",
 "tokens = torch.LongTensor([[0, *token_ids, 0]]).to(device)\n",
 "print(tokens.shape)\n",
 "\n",
 "# Only axis 1 (sequence length) is marked dynamic; the batch axis stays fixed.\n",
 "bert_dynamic_axes = {\"tokens\": {1: \"ntokens\"}}\n",
 "torch.onnx.export(model=bert, args=(tokens,), dynamic_axes=bert_dynamic_axes, input_names=[\"tokens\"], f=\"bert.onnx\", verbose=False, dynamo=False)\n"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "import onnx\n",
 "import onnxruntime as ort\n",
 "\n",
 "# NOTE(review): the last run of this cell failed while creating the session:\n",
 "#   Node (/Transpose_9) Op (Transpose) [TypeInferenceError] Invalid attribute perm {1, -1, 0}, input shape = {0, 0, 128}\n",
 "# i.e. style_model.onnx contains a Transpose whose perm has a negative entry, which ONNX\n",
 "# rejects. full_check also runs ONNX shape inference, so this should now fail at the\n",
 "# checker, before onnxruntime is involved. TODO: fix the style-model export.\n",
 "onnx_model = onnx.load(\"style_model.onnx\")\n",
 "onnx.checker.check_model(onnx_model, full_check=True)\n",
 "\n",
 "ort_session = ort.InferenceSession(\"style_model.onnx\", providers=[\"CPUExecutionProvider\"])\n",
 "# onnxruntime consumes host numpy arrays; `tokens` may live on a CUDA `device`.\n",
 "outputs = ort_session.run(None, {\"tokens\": tokens.cpu().numpy()})"
 ] } ], "metadata": { "kernelspec": { "display_name": "styletts2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }