{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "animegan_v2_for_videos.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyP/bydrfrVmE0CzRt9JBw+x",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "dufmM-T1Helt"
},
"source": [
"%%capture\n",
"# %pip (rather than `! pip`) installs into the environment of the running kernel.\n",
"%pip install gradio encoded-video"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9CY3n8A0Lvdi"
},
"source": [
"import gc\n",
"import math\n",
"import tempfile\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"\n",
"import torch\n",
"import gradio as gr\n",
"import numpy as np\n",
"from encoded_video import EncodedVideo, write_video\n",
"from torchvision.transforms.functional import to_tensor, center_crop"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "YxdCnrTzLw5V"
},
"source": [
"# Load the pretrained generator from torch.hub. Use the GPU when one is\n",
"# available and fall back to the CPU otherwise, so the cell also runs on\n",
"# CPU-only runtimes.\n",
"model = torch.hub.load(\n",
"    \"AK391/animegan2-pytorch:main\",\n",
"    \"generator\",\n",
"    pretrained=True,\n",
"    device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
"    progress=True,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TYAyXUP1UeOd"
},
"source": [
"# -f: fail on HTTP errors instead of saving an error page as obama.webm; -L: follow redirects.\n",
"! curl -fL \"https://upload.wikimedia.org/wikipedia/commons/transcoded/2/29/2017-01-07_President_Obama%27s_Weekly_Address.webm/2017-01-07_President_Obama%27s_Weekly_Address.webm.360p.vp9.webm\" -o obama.webm"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TxT45Nlc88tD"
},
"source": [
"def face2paint(model: torch.nn.Module, img: Image.Image, size: int = 512, device: str = 'cuda'):\n",
"    \"\"\"Stylize a single PIL image with the given generator.\n",
"\n",
"    The image is converted to RGB, center-cropped to a square, resized to\n",
"    ``size`` x ``size``, scaled to [-1, 1] and passed through ``model``.\n",
"\n",
"    Args:\n",
"        model: Generator network applied to the preprocessed image.\n",
"        img: Input image in any PIL mode; it is converted to RGB first.\n",
"        size: Side length of the square image fed to the model.\n",
"        device: Device the input batch is moved to before the forward pass.\n",
"\n",
"    Returns:\n",
"        The model output for the single batch item as a CPU tensor, mapped\n",
"        from [-1, 1] back to [0, 255] and clipped to that range.\n",
"    \"\"\"\n",
"    # RGBA / grayscale / palette images would otherwise produce a tensor with\n",
"    # 4 or 1 channels instead of 3.\n",
"    img = img.convert('RGB')\n",
"\n",
"    w, h = img.size\n",
"    s = min(w, h)\n",
"    img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))\n",
"    img = img.resize((size, size), Image.LANCZOS)\n",
"\n",
"    with torch.no_grad():\n",
"        # Named `batch` rather than `input` to avoid shadowing the builtin.\n",
"        batch = to_tensor(img).unsqueeze(0) * 2 - 1\n",
"        output = model(batch.to(device)).cpu()[0]\n",
"\n",
"    # Undo the [-1, 1] normalization and express the result on a 0-255 scale.\n",
"    output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
"\n",
"    return output\n",
"\n",
"# This function is taken from pytorchvideo!\n",
"def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:\n",
"    \"\"\"\n",
"    Uniformly subsamples num_samples indices from the temporal dimension of the video.\n",
"    When num_samples is larger than the size of temporal dimension of the video, it\n",
"    will sample frames based on nearest neighbor interpolation.\n",
"    Args:\n",
"        x (torch.Tensor): A video tensor with dimension larger than one with torch\n",
"            tensor type includes int, long, float, complex, etc.\n",
"        num_samples (int): The number of equispaced samples to be selected\n",
"        temporal_dim (int): dimension of temporal to perform temporal subsample.\n",
"    Returns:\n",
"        An x-like Tensor with subsampled temporal dimension.\n",
"    Raises:\n",
"        AssertionError: if num_samples or the temporal size is not positive.\n",
"    \"\"\"\n",
"    t = x.shape[temporal_dim]\n",
"    assert num_samples > 0 and t > 0\n",
"    # Sample by nearest neighbor interpolation if num_samples > t.\n",
"    indices = torch.linspace(0, t - 1, num_samples)\n",
"    # linspace allocates on the CPU; index_select needs the index tensor on the\n",
"    # same device as x, so move it there after truncating to integer indices.\n",
"    indices = torch.clamp(indices, 0, t - 1).long().to(x.device)\n",
"    return torch.index_select(x, temporal_dim, indices)\n",
"\n",
"\n",
"def short_side_scale(\n",
"    x: torch.Tensor,\n",
"    size: int,\n",
"    interpolation: str = \"bilinear\",\n",
") -> torch.Tensor:\n",
"    \"\"\"\n",
"    Determines the shorter spatial dim of the video (i.e. width or height) and scales\n",
"    it to the given size. To maintain aspect ratio, the longer side is then scaled\n",
"    accordingly.\n",
"    Args:\n",
"        x (torch.Tensor): A video tensor of shape (C, T, H, W) and type torch.float32.\n",
"        size (int): The size the shorter side is scaled to.\n",
"        interpolation (str): Algorithm used for resampling, passed as ``mode`` to\n",
"            ``torch.nn.functional.interpolate``. For the 4-D input required here\n",
"            use 'nearest' | 'bilinear' | 'bicubic' | 'area'.\n",
"    Returns:\n",
"        An x-like Tensor with scaled spatial dims.\n",
"    Raises:\n",
"        AssertionError: if x is not 4-dimensional or not torch.float32.\n",
"    \"\"\"\n",
"    assert len(x.shape) == 4\n",
"    assert x.dtype == torch.float32\n",
"    h, w = x.shape[-2:]\n",
"    if w < h:\n",
"        new_h = int(math.floor((float(h) / w) * size))\n",
"        new_w = size\n",
"    else:\n",
"        new_h = size\n",
"        new_w = int(math.floor((float(w) / h) * size))\n",
"\n",
"    # interpolate() only accepts align_corners for the interpolating modes;\n",
"    # passing it together with e.g. 'nearest' or 'area' raises ValueError.\n",
"    interpolating_modes = ('linear', 'bilinear', 'bicubic', 'trilinear')\n",
"    align_corners = False if interpolation in interpolating_modes else None\n",
"\n",
"    return torch.nn.functional.interpolate(\n",
"        x, size=(new_h, new_w), mode=interpolation, align_corners=align_corners\n",
"    )\n",
"\n",
"def inference_step(vid, start_sec, duration, out_fps):\n",
"    \"\"\"Stylize one clip of a video with the module-level ``model``.\n",
"\n",
"    Args:\n",
"        vid: Video object providing ``get_clip(start, end)``, which returns a\n",
"            dict with 'video' and 'audio' entries.\n",
"        start_sec: Clip start time in seconds.\n",
"        duration: Clip length in seconds.\n",
"        out_fps: Number of frames to keep per second of the clip.\n",
"\n",
"    Returns:\n",
"        Tuple ``(output_video, audio_arr, out_fps, audio_fps)`` where\n",
"        ``output_video`` is a float numpy array of stylized frames on a 0-255\n",
"        scale with channels last, ``audio_arr`` is the clip audio with a leading\n",
"        axis added, and ``audio_fps`` is the audio sample rate (None when the\n",
"        video has no audio).\n",
"    \"\"\"\n",
"    clip = vid.get_clip(start_sec, start_sec + duration)\n",
"    # Move the last axis first; presumably (T, H, W, C) -> (C, T, H, W) — TODO confirm.\n",
"    video_arr = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)\n",
"    audio_arr = np.expand_dims(clip['audio'], 0)\n",
"    audio_fps = None if not vid._has_audio else vid._container.streams.audio[0].sample_rate\n",
"\n",
"    # torch.linspace needs an integer step count, so guard against float inputs.\n",
"    num_frames = int(round(duration * out_fps))\n",
"    x = uniform_temporal_subsample(video_arr, num_frames)\n",
"    # short_side_scale asserts float32 input; .float() is a no-op if it already is.\n",
"    x = center_crop(short_side_scale(x.float(), 512), 512)\n",
"    # Scale to [-1, 1], matching face2paint's preprocessing and the\n",
"    # `* 0.5 + 0.5` de-normalization applied to the model output below.\n",
"    x = x / 255. * 2 - 1\n",
"    # Swap the first two axes so the frames form the batch dimension.\n",
"    x = x.permute(1, 0, 2, 3)\n",
"    # Run on whatever device the model lives on instead of assuming CUDA.\n",
"    model_device = next(model.parameters()).device\n",
"    with torch.no_grad():\n",
"        output = model(x.to(model_device)).cpu()\n",
"    output = (output * 0.5 + 0.5).clip(0, 1) * 255.\n",
"    output_video = output.permute(0, 2, 3, 1).numpy()\n",
"\n",
"    return output_video, audio_arr, out_fps, audio_fps\n",
"\n",
"def predict_fn(filepath, start_sec, duration, out_fps):\n",
"    \"\"\"Stylize ``duration`` seconds of a video and write the result to 'out.mp4'.\n",
"\n",
"    The video is processed one second at a time via ``inference_step``; the\n",
"    per-second frames and audio are joined and written with ``write_video``.\n",
"\n",
"    Args:\n",
"        filepath: Path of the input video.\n",
"        start_sec: Offset in seconds at which processing starts.\n",
"        duration: Number of one-second chunks to process (at least 1).\n",
"        out_fps: Frames per second of the output video.\n",
"\n",
"    Returns:\n",
"        The path of the written file, 'out.mp4'.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``duration`` is smaller than 1.\n",
"    \"\"\"\n",
"    # range() needs an int; UI sliders may deliver the value as a float.\n",
"    duration = int(duration)\n",
"    if duration < 1:\n",
"        raise ValueError('duration must be at least 1 second')\n",
"\n",
"    vid = EncodedVideo.from_path(filepath)\n",
"\n",
"    # Collect the chunks and join them once at the end instead of re-copying\n",
"    # the growing arrays on every iteration.\n",
"    video_chunks = []\n",
"    audio_chunks = []\n",
"    for i in range(duration):\n",
"        video, audio, fps, audio_fps = inference_step(\n",
"            vid=vid,\n",
"            start_sec=i + start_sec,\n",
"            duration=1,\n",
"            out_fps=out_fps\n",
"        )\n",
"        video_chunks.append(video)\n",
"        audio_chunks.append(audio)\n",
"        gc.collect()\n",
"\n",
"    write_video(\n",
"        'out.mp4',\n",
"        np.concatenate(video_chunks),\n",
"        fps=fps,\n",
"        audio_array=np.hstack(audio_chunks),\n",
"        audio_fps=audio_fps,\n",
"        audio_codec='aac'\n",
"    )\n",
"\n",
"    return 'out.mp4'\n",
"\n",
"# NOTE(review): the markup that originally surrounded this text (presumably a\n",
"# link to the GitHub repo) is missing here — restore it from the upstream notebook.\n",
"article = \"\"\"\n",
"\n",
" Github Repo Pytorch\n",
"\n",
"\"\"\"\n",
"\n",
"gr.Interface(\n",
"    predict_fn,\n",
"    # Inputs map positionally to predict_fn(filepath, start_sec, duration, out_fps).\n",
"    inputs=[\n",
"        gr.inputs.Video(),\n",
"        gr.inputs.Slider(minimum=0, maximum=300, step=1, default=0),\n",
"        gr.inputs.Slider(minimum=1, maximum=10, step=1, default=2),\n",
"        gr.inputs.Slider(minimum=12, maximum=30, step=6, default=24),\n",
"    ],\n",
"    outputs=gr.outputs.Video(),\n",
"    title='AnimeGANV2 On Videos',\n",
"    description=\"Applying AnimeGAN-V2 to frame from video clips\",\n",
"    article = article,\n",
"    enable_queue=True,\n",
"    examples=[\n",
"        ['obama.webm', 23, 10, 30],\n",
"    ],\n",
"    allow_flagging=False\n",
").launch(debug=True)"
],
"execution_count": null,
"outputs": []
}
]
}