diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cd2fe3fa9c8535dfc6419c2fdedf7e810849af64 --- /dev/null +++ b/README.md @@ -0,0 +1,31 @@ +# creco-inference +Unified inference code for SageMaker and Hugging Face endpoints + +## Deployment + +- Inference code (this) should be placed in the model folder respectively, + +### SageMaker + +``` +model/ + code/ + (repo) <-- The repo inference code as direct child (no sub-folder) + vae + unet + ... +``` + +- Refer `deployment.ipynb` for creating endpoint. + +### Hugging Face + +``` +model/ + (repo) <-- The repo inference code as direct child (no sub-folder) + vae + unet + ... +``` + +- Refer [doc](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint) to create endpoint. diff --git a/config.yml b/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..55fd91b5bcacd654e3045a2331e9c186818e6edc --- /dev/null +++ b/config.yml @@ -0,0 +1,157 @@ +run_title: b18_ffc075_batch8x15 +training_model: + kind: default + visualize_each_iters: 1000 + concat_mask: true + store_discr_outputs_for_vis: true +losses: + l1: + weight_missing: 0 + weight_known: 10 + perceptual: + weight: 0 + adversarial: + kind: r1 + weight: 10 + gp_coef: 0.001 + mask_as_fake_target: true + allow_scale_mask: true + feature_matching: + weight: 100 + resnet_pl: + weight: 30 + weights_path: ${env:TORCH_HOME} + +optimizers: + generator: + kind: adam + lr: 0.001 + discriminator: + kind: adam + lr: 0.0001 +visualizer: + key_order: + - image + - predicted_image + - discr_output_fake + - discr_output_real + - inpainted + rescale_keys: + - discr_output_fake + - discr_output_real + kind: directory + outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples +location: + data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large + out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments + tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs +data: + batch_size: 15 + val_batch_size: 2 + num_workers: 3 + train: + indir: ${location.data_root_dir}/train + out_size: 256 + mask_gen_kwargs: + irregular_proba: 1 + irregular_kwargs: + max_angle: 4 + max_len: 200 + max_width: 100 + max_times: 5 + min_times: 1 + box_proba: 1 + box_kwargs: + margin: 10 + bbox_min_size: 30 + bbox_max_size: 150 + max_times: 3 + min_times: 1 + segm_proba: 0 + segm_kwargs: + confidence_threshold: 0.5 + max_object_area: 0.5 + min_mask_area: 0.07 + downsample_levels: 6 + num_variants_per_mask: 1 + rigidness_mode: 1 + max_foreground_coverage: 0.3 + max_foreground_intersection: 0.7 + max_mask_intersection: 0.1 + max_hidden_area: 0.1 + max_scale_change: 0.25 + horizontal_flip: true + max_vertical_shift: 0.2 + position_shuffle: true + transform_variant: distortions + dataloader_kwargs: + batch_size: ${data.batch_size} + shuffle: true + num_workers: ${data.num_workers} + val: + indir: ${location.data_root_dir}/val + img_suffix: .png + dataloader_kwargs: + batch_size: ${data.val_batch_size} + shuffle: false + num_workers: ${data.num_workers} + visual_test: + indir: ${location.data_root_dir}/korean_test + img_suffix: _input.png + pad_out_to_modulo: 32 + dataloader_kwargs: + batch_size: 1 + shuffle: false + num_workers: ${data.num_workers} +generator: + kind: ffc_resnet + input_nc: 4 + output_nc: 3 + ngf: 64 + n_downsampling: 3 + n_blocks: 18 + add_out_act: sigmoid + init_conv_kwargs: + ratio_gin: 0 + ratio_gout: 0 + enable_lfu: false + downsample_conv_kwargs: + ratio_gin: ${generator.init_conv_kwargs.ratio_gout} + ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin} + enable_lfu: false + resnet_conv_kwargs: + ratio_gin: 0.75 + ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin} + enable_lfu: false +discriminator: + kind: pix2pixhd_nlayer + input_nc: 3 + ndf: 64 + n_layers: 4 +evaluator: + kind: default + inpainted_key: inpainted + integral_kind: ssim_fid100_f1 +trainer: + kwargs: + gpus: -1 + accelerator: ddp + max_epochs: 200 + gradient_clip_val: 1 + log_gpu_memory: None + limit_train_batches: 25000 + val_check_interval: ${trainer.kwargs.limit_train_batches} + log_every_n_steps: 1000 + precision: 32 + terminate_on_nan: false + check_val_every_n_epoch: 1 + num_sanity_val_steps: 8 + limit_val_batches: 1000 + replace_sampler_ddp: false + checkpoint_kwargs: + verbose: true + save_top_k: 5 + save_last: true + period: 1 + monitor: val_ssim_fid100_f1_total_mean + mode: max diff --git a/deployment.ipynb b/deployment.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5b79b09bcdf222084ed0a13f1a0fa5505cefd4ce --- /dev/null +++ b/deployment.ipynb @@ -0,0 +1,995 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "5af7e53b-80ff-4058-888d-fe41804f64ba", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n", + "Requirement already satisfied: pip in /home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages (23.1.2)\n" + ] + } + ], + "source": [ + "!pip install --upgrade pip\n", + "!pip install \"sagemaker==2.116.0\" \"huggingface_hub==0.10.1\" --upgrade --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "93ee3d96-400f-46b4-8eb3-0f3f3c853a7e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from distutils.dir_util import copy_tree\n", + "from pathlib import Path\n", + "from huggingface_hub import snapshot_download\n", + "import random\n", + "import os\n", + "import tarfile\n", + "import time\n", + "import sagemaker\n", + "from datetime import datetime\n", + "from sagemaker.s3 import S3Uploader\n", + "import boto3\n", + "from sagemaker.huggingface.model import HuggingFaceModel\n", + "from threading import Thread\n", + "import subprocess\n", + "import shutil" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2db37b03-b517-46bc-8602-4999a64399c0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# ------------------------------------------------\n", + "# Configuration\n", + "# ------------------------------------------------\n", + "STAGE = \"prod\"\n", + "model_configs = [\n", + " # {\n", + " # \"inference_2\": False, \n", + " # \"path\": \"icbinp\",\n", + " # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n", + " # #\"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " # },\n", + " # {\n", + " # \"inference_2\": False, \n", + " # \"path\": \"icb_with_epi\",\n", + " # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n", + " # # \"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " # },\n", + " {\n", + " \"inference_2\": False, \n", + " \"path\": \"model_v9\",\n", + " # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n", + " \"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " },\n", + " {\n", + " \"inference_2\": False, \n", + " \"path\": \"model_v8\",\n", + " #\"endpoint_name\": \"gamma-10001-2023-05-08-06-14\"\n", + " \"endpoint_name\": f\"{STAGE}-10001-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " },\n", + " # {\n", + " # \"inference_2\": False, \n", + " # \"path\": \"model_v5_anime\",\n", + " # \"endpoint_name\": \"gamma-10001-2023-05-08-06-14\"\n", + " # #\"endpoint_name\": f\"{STAGE}-10001-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " # },\n", + " # {\n", + " # \"inference_2\": False, \n", + " # \"path\": \"model_v5.3_comic\",\n", + " # #\"endpoint_name\": \"gamma-10002-2023-05-08-07-22\"\n", + " # \"endpoint_name\": f\"{STAGE}-10002-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " # },\n", + " {\n", + " \"inference_2\": False, \n", + " \"path\": \"model_v10\",\n", + " # \"endpoint_name\": \"gamma-10002-2023-05-08-07-22\"\n", + " \"endpoint_name\": f\"{STAGE}-10002-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " },\n", + " {\n", + " \"inference_2\": True, \n", + " \"path\": \"model_v5.2_other\",\n", + " # \"endpoint_name\": \"gamma-other-2023-05-04-09-33\"\n", + " \"endpoint_name\": f\"{STAGE}-other-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " }\n", + " # {\n", + " # \"inference_2\": False, \n", + " # \"path\": \"model_v6_bheem\",\n", + " # \"endpoint_name\": f\"{STAGE}-10003-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " # },\n", + " # {\n", + " # \"inference_2\": False, \n", + " # \"path\": \"model_v12\",\n", + " # \"endpoint_name\": \"gamma-10003-2023-05-04-05-20\"\n", + " # # \"endpoint_name\": f\"{STAGE}-10003-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n", + " # }\n", + "]\n", + "\n", + "VpcConfig = {\n", + " \"Subnets\": [\n", + " \"subnet-0df3f71df4c7b29e5\",\n", + " \"subnet-0d753b7fc74b5ee68\"\n", + " ],\n", + " \"SecurityGroupIds\": [\n", + " \"sg-033a7948e79a501cd\"\n", + " ]\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d7322ac4-aeeb-4a72-a662-5f3fa74e6454", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def compress(tar_dir=None,output_file=\"model.tar.gz\"):\n", + " parent_dir=os.getcwd()\n", + " os.chdir(parent_dir + \"/\" + tar_dir)\n", + " with tarfile.open(os.path.join(parent_dir, output_file), \"w:gz\") as tar:\n", + " for item in os.listdir('.'):\n", + " print(\"- \" + item)\n", + " tar.add(item, arcname=item)\n", + " os.chdir(parent_dir)\n", + "\n", + " \n", + "def create_model_tar(config):\n", + " print(\"Copying inference 'code': \" + config.get(\"path\"))\n", + " \n", + " model_tar = Path(config.get(\"path\"))\n", + " if os.path.exists(model_tar.joinpath(\"code\")):\n", + " shutil.rmtree(model_tar.joinpath(\"code\"))\n", + " out_tar = config.get(\"path\") + \".tar.gz\"\n", + " model_tar.mkdir(exist_ok=True)\n", + " copy_tree(\"code/\", str(model_tar.joinpath(\"code\")))\n", + " copy_tree(\"laur_style/\", str(model_tar.joinpath(\"laur_style\")))\n", + " \n", + " if config.get(\"inference_2\"):\n", + " os.remove(model_tar.joinpath(\"code\").joinpath(\"inference.py\"))\n", + " os.rename(model_tar.joinpath(\"code\").joinpath(\"inference2.py\"), model_tar.joinpath(\"code\").joinpath(\"inference.py\"))\n", + " \n", + " print(\"Compressing: \" + config.get(\"path\"))\n", + "\n", + " if os.path.exists(out_tar):\n", + " os.remove(out_tar)\n", + "\n", + " compress(str(model_tar), out_tar)\n", + " \n", + "def upload_to_s3(config):\n", + " out_tar = config.get(\"path\") + \".tar.gz\"\n", + " print(\"Uploading model to S3: \" + out_tar)\n", + " s3_model_uri=S3Uploader.upload(local_path=out_tar, desired_s3_uri=f\"s3://comic-assets/stable-diffusion-v1-4/v2/\")\n", + " return s3_model_uri\n", + " \n", + " \n", + "def deploy_and_create_endpoint(config, s3_model_uri):\n", + " sess = sagemaker.Session()\n", + " # sagemaker session bucket -> used for uploading data, models and logs\n", + " # sagemaker will automatically create this bucket if it not exists\n", + " sagemaker_session_bucket=None\n", + " if sagemaker_session_bucket is None and sess is not None:\n", + " # set to default bucket if a bucket name is not given\n", + " sagemaker_session_bucket = sess.default_bucket()\n", + " try:\n", + " role = sagemaker.get_execution_role()\n", + " except ValueError:\n", + " iam = boto3.client('iam')\n", + " role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n", + "\n", + " sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n", + " \n", + " huggingface_model = HuggingFaceModel(\n", + " model_data=s3_model_uri, # path to your model and script\n", + " role=role, # iam role with permissions to create an Endpoint\n", + " transformers_version=\"4.17\", # transformers version used\n", + " pytorch_version=\"1.10\", # pytorch version used\n", + " py_version='py38',# python version used\n", + " vpc_config=VpcConfig,\n", + " )\n", + "\n", + " print(\"Creating endpoint: \" + config.get(\"endpoint_name\"))\n", + "\n", + " predictor = huggingface_model.deploy(\n", + " initial_instance_count=1,\n", + " instance_type=\"ml.g4dn.xlarge\",\n", + " endpoint_name=config.get(\"endpoint_name\")\n", + " )\n", + "\n", + " \n", + "def start_process(config):\n", + " try:\n", + " create_model_tar(config)\n", + " s3_model_uri = upload_to_s3(config)\n", + " #s3_model_uri = \"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.2_other.tar.gz\"\n", + " deploy_and_create_endpoint(config, s3_model_uri)\n", + " except Exception as e:\n", + " print(\"Failed to deploy: \" + config.get(\"path\") + \"\\n\" + str(e))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "cdc04669-90a5-4b43-8499-ad1d2dd63a4c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Copying inference 'code': model_v9\n", + "Compressing: model_v9\n", + "- scheduler\n", + "- vae\n", + "- .ipynb_checkpoints\n", + "- feature_extractor\n", + "- tokenizer\n", + "- text_encoder\n", + "- model_index.json\n", + "- laur_style\n", + "- code\n", + "- unet\n", + "- args.json\n", + "Uploading model to S3: model_v9.tar.gz\n", + "Creating endpoint: gamma-10000-2023-05-16-14-55\n", + "-----------------!\n", + "\n", + "Completed in : 992.3517553806305s\n" + ] + } + ], + "source": [ + "threads = []\n", + "\n", + "os.chdir(\"/home/ec2-user/SageMaker\")\n", + "\n", + "start_time = time.time()\n", + "\n", + "for config in model_configs:\n", + " thread = Thread(target=start_process, args=(config,))\n", + " thread.start()\n", + " thread.join()\n", + " threads.append(thread)\n", + "\n", + "for thread in threads:\n", + " thread.join()\n", + " \n", + "print(\"\\n\\nCompleted in : \" + str(time.time() - start_time) + \"s\")\n", + "\n", + "# For redeploying gamma endpoints or promoting gamma endpoints to prod\n", + "\n", + "# thread1 = Thread(target=deploy_and_create_endpoint, args=(model_configs[0],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v9.tar.gz\",))\n", + "# thread2 = Thread(target=deploy_and_create_endpoint, args=(model_configs[1],\"s3://comic-assets/stable-diffusion-v1-4/v2//anime_mode_with_lora.tar.gz\",))\n", + "# thread3 = Thread(target=deploy_and_create_endpoint, args=(model_configs[0],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.3_comic.tar.gz\",))\n", + "# thread4 = Thread(target=deploy_and_create_endpoint, args=(model_configs[3],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.2_other.tar.gz\",))\n", + "\n", + "# thread1.start()\n", + "# thread2.start()\n", + "# thread3.start()\n", + "# thread4.start()\n", + "\n", + "# thread1.join()\n", + "# thread2.join()\n", + "# thread3.join()\n", + "# thread4.join()\n", + "\n", + "# print(\"Done\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39f007f2-0ff8-487c-b5d7-158f0947b7fd", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "\n", + "# import sagemaker\n", + "# import boto3\n", + "# import time \n", + "\n", + "# start = time.time()\n", + "\n", + "# sess = sagemaker.Session()\n", + "# # sagemaker session bucket -> used for uploading data, models and logs\n", + "# # sagemaker will automatically create this bucket if it not exists\n", + "# sagemaker_session_bucket=None\n", + "# if sagemaker_session_bucket is None and sess is not None:\n", + "# # set to default bucket if a bucket name is not given\n", + "# sagemaker_session_bucket = sess.default_bucket()\n", + "\n", + "# try:\n", + "# role = sagemaker.get_execution_role()\n", + "# except ValueError:\n", + "# iam = boto3.client('iam')\n", + "# role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n", + "\n", + "# sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n", + "\n", + "# print(f\"sagemaker role arn: {role}\")\n", + "# print(f\"sagemaker bucket: {sess.default_bucket()}\")\n", + "# print(f\"sagemaker session region: {sess.boto_region_name}\")\n", + "# print(sagemaker.get_execution_role())\n", + "\n", + "# from sagemaker.s3 import S3Uploader\n", + "\n", + "# print(\"Uploading model to S3\")\n", + "\n", + "# # upload model.tar.gz to s3\n", + "# s3_model_uri=S3Uploader.upload(local_path=\"model.tar.gz\", desired_s3_uri=f\"s3://comic-assets/stable-diffusion-v1-4/v2/\")\n", + "\n", + "# print(f\"model uploaded to: {s3_model_uri}\")\n", + "\n", + "\n", + "# from sagemaker.huggingface.model import HuggingFaceModel\n", + "\n", + "# VpcConfig = {\n", + "# \"Subnets\": [\n", + "# \"subnet-0df3f71df4c7b29e5\",\n", + "# \"subnet-0d753b7fc74b5ee68\"\n", + "# ],\n", + "# \"SecurityGroupIds\": [\n", + "# \"sg-033a7948e79a501cd\"\n", + "# ]\n", + "# }\n", + "\n", + "# # create Hugging Face Model Class\n", + "# huggingface_model = HuggingFaceModel(\n", + "# model_data=s3_model_uri, # path to your model and script\n", + "# role=role, # iam role with permissions to create an Endpoint\n", + "# transformers_version=\"4.17\", # transformers version used\n", + "# pytorch_version=\"1.10\", # pytorch version used\n", + "# py_version='py38',# python version used\n", + "# vpc_config=VpcConfig,\n", + "# )\n", + "\n", + "# print(\"Deploying model\")\n", + "\n", + "# predictor = huggingface_model.deploy(\n", + "# initial_instance_count=1,\n", + "# instance_type=\"ml.g4dn.xlarge\",\n", + "# # endpoint_name=endpoint_name\n", + "# )\n", + "\n", + "# print(f\"Done {time.time() - start}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa95a262-d6ba-4e61-8657-6f8e5bab74a1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "524ca546-2a67-4b51-9cda-a1b51a49c339", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!sudo yum install git-lfs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c7e661f-5eee-4357-80f6-e7563941a812", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + } + ], + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "conda_pytorch_p39", + "language": "python", + "name": "conda_pytorch_p39" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/handler.py b/handler.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3f8e4ed7911d7d4fe464b837c0ebc04f76ef04 --- /dev/null +++ b/handler.py @@ -0,0 +1,11 @@ +from typing import Any, Dict, List + +from inference import model_fn, predict_fn + + +class EndpointHandler: + def __init__(self, path=""): + return model_fn(path) + + def __call__(self, data: Any) -> List[List[Dict[str, float]]]: + return predict_fn(data, None) diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ea091e4ff3b250f04eb0de0b47e5079a2defcb4b --- /dev/null +++ b/inference.py @@ -0,0 +1,341 @@ +from typing import List, Optional + +import torch + +from internals.data.dataAccessor import update_db +from internals.data.task import Task, TaskType +from internals.pipelines.commons import Img2Img, Text2Img +from internals.pipelines.controlnets import ControlNet +from internals.pipelines.img_classifier import ImageClassifier +from internals.pipelines.img_to_text import Image2Text +from internals.pipelines.prompt_modifier import PromptModifier +from internals.pipelines.safety_checker import SafetyChecker +from internals.util.args import apply_style_args +from internals.util.avatar import Avatar +from internals.util.cache import auto_clear_cuda_and_gc +from internals.util.commons import pickPoses, upload_image, upload_images +from internals.util.config import set_configs_from_task, set_root_dir +from internals.util.failure_hander import FailureHandler +from internals.util.lora_style import LoraStyle +from internals.util.slack import Slack + +torch.backends.cudnn.benchmark = True +torch.backends.cuda.matmul.allow_tf32 = True + +num_return_sequences = 4 # the number of results to generate +auto_mode = False + +prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences) +img2text = Image2Text() +img_classifier = ImageClassifier() +controlnet = ControlNet() +lora_style = LoraStyle() +text2img_pipe = Text2Img() +img2img_pipe = Img2Img() +safety_checker = SafetyChecker() +slack = Slack() +avatar = Avatar() + + +def get_patched_prompt(task: Task): + def add_style_and_character(prompt: List[str], additional: Optional[str] = None): + for i in range(len(prompt)): + prompt[i] = avatar.add_code_names(prompt[i]) + prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style()) + if additional: + prompt[i] = additional + " " + prompt[i] + + prompt = task.get_prompt() + + if task.is_prompt_engineering(): + prompt = prompt_modifier.modify(prompt) + else: + prompt = [prompt] * num_return_sequences + + ori_prompt = [task.get_prompt()] * num_return_sequences + + class_name = None + # if task.get_imageUrl(): + # class_name = img_classifier.classify( + # task.get_imageUrl(), task.get_width(), task.get_height() + # ) + add_style_and_character(ori_prompt, class_name) + add_style_and_character(prompt, class_name) + + print({"prompts": prompt}) + + return (prompt, ori_prompt) + + +def get_patched_prompt_tile_upscale(task: Task): + if task.get_prompt(): + prompt = task.get_prompt() + else: + prompt = img2text.process(task.get_imageUrl()) + + prompt = avatar.add_code_names(prompt) + prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style()) + + class_name = img_classifier.classify( + task.get_imageUrl(), task.get_width(), task.get_height() + ) + prompt = class_name + " " + prompt + + print({"prompt": prompt}) + + return prompt + + +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def canny(task: Task): + prompt, _ = get_patched_prompt(task) + + controlnet.load_canny() + + # pipe2 is used for canny and pose + lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style()) + lora_patcher.patch() + + images, has_nsfw = controlnet.process_canny( + prompt=prompt, + imageUrl=task.get_imageUrl(), + seed=task.get_seed(), + steps=task.get_steps(), + width=task.get_width(), + height=task.get_height(), + guidance_scale=task.get_cy_guidance_scale(), + negative_prompt=[ + f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}" + ] + * num_return_sequences, + **lora_patcher.kwargs(), + ) + + generated_image_urls = upload_images(images, "_canny", task.get_taskId()) + + lora_patcher.cleanup() + controlnet.cleanup() + + return { + "modified_prompts": prompt, + "generated_image_urls": generated_image_urls, + "has_nsfw": has_nsfw, + } + + +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def tile_upscale(task: Task): + output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId()) + + prompt = get_patched_prompt_tile_upscale(task) + + controlnet.load_tile_upscaler() + + lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style()) + lora_patcher.patch() + + images, has_nsfw = controlnet.process_tile_upscaler( + imageUrl=task.get_imageUrl(), + seed=task.get_seed(), + steps=task.get_steps(), + width=task.get_width(), + height=task.get_height(), + prompt=prompt, + resize_dimension=task.get_resize_dimension(), + negative_prompt=task.get_negative_prompt(), + guidance_scale=task.get_ti_guidance_scale(), + ) + + generated_image_url = upload_image(images[0], output_key) + + lora_patcher.cleanup() + controlnet.cleanup() + + return { + "modified_prompts": prompt, + "generated_image_url": generated_image_url, + "has_nsfw": has_nsfw, + } + + +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None): + prompt, _ = get_patched_prompt(task) + + controlnet.load_pose() + + # pipe2 is used for canny and pose + lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style()) + lora_patcher.patch() + + if poses is None: + poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences + + images, has_nsfw = controlnet.process_pose( + prompt=prompt, + image=poses, + seed=task.get_seed(), + steps=task.get_steps(), + negative_prompt=[task.get_negative_prompt()] * num_return_sequences, + width=task.get_width(), + height=task.get_height(), + guidance_scale=task.get_po_guidance_scale(), + **lora_patcher.kwargs(), + ) + + generated_image_urls = upload_images(images, s3_outkey, task.get_taskId()) + + lora_patcher.cleanup() + controlnet.cleanup() + + return { + "modified_prompts": prompt, + "generated_image_urls": generated_image_urls, + "has_nsfw": has_nsfw, + } + + +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def text2img(task: Task): + prompt, ori_prompt = get_patched_prompt(task) + + lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style()) + lora_patcher.patch() + + torch.manual_seed(task.get_seed()) + + images, has_nsfw = text2img_pipe.process( + prompt=ori_prompt, + modified_prompts=prompt, + num_inference_steps=task.get_steps(), + guidance_scale=7.5, + height=task.get_height(), + width=task.get_width(), + negative_prompt=[task.get_negative_prompt()] * num_return_sequences, + iteration=task.get_iteration(), + **lora_patcher.kwargs(), + ) + + generated_image_urls = upload_images(images, "", task.get_taskId()) + + lora_patcher.cleanup() + + return { + "modified_prompts": prompt, + "generated_image_urls": generated_image_urls, + "has_nsfw": has_nsfw, + } + + +@update_db +@auto_clear_cuda_and_gc(controlnet) +@slack.auto_send_alert +def img2img(task: Task): + prompt, _ = get_patched_prompt(task) + + lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style()) + lora_patcher.patch() + + torch.manual_seed(task.get_seed()) + + images, has_nsfw = img2img_pipe.process( + prompt=prompt, + imageUrl=task.get_imageUrl(), + negative_prompt=[task.get_negative_prompt()] * num_return_sequences, + steps=task.get_steps(), + width=task.get_width(), + height=task.get_height(), + strength=task.get_i2i_strength(), + guidance_scale=task.get_i2i_guidance_scale(), + **lora_patcher.kwargs(), + ) + + generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId()) + + lora_patcher.cleanup() + + return { + "modified_prompts": prompt, + "generated_image_urls": generated_image_urls, + "has_nsfw": has_nsfw, + } + + +def model_fn(model_dir): + print("Logs: model loaded .... starts") + + set_root_dir(__file__) + + FailureHandler.register() + + avatar.load_local() + + prompt_modifier.load() + img2text.load() + img_classifier.load() + + lora_style.load(model_dir) + safety_checker.load() + + controlnet.load(model_dir) + text2img_pipe.load(model_dir) + img2img_pipe.create(text2img_pipe) + + safety_checker.apply(text2img_pipe) + safety_checker.apply(img2img_pipe) + safety_checker.apply(controlnet) + + print("Logs: model loaded ....") + return + + +@FailureHandler.clear +def predict_fn(data, pipe): + task = Task(data) + print("task is ", data) + + FailureHandler.handle(task) + + try: + # Set set_environment + set_configs_from_task(task) + + # Apply arguments + apply_style_args(data) + + # Re-fetch styles + lora_style.fetch_styles() + + # Fetch avatars + avatar.fetch_from_network(task.get_model_id()) + + task_type = task.get_type() + + if task_type == TaskType.TEXT_TO_IMAGE: + # character sheet + if "character sheet" in task.get_prompt().lower(): + return pose(task, s3_outkey="", poses=pickPoses()) + else: + return text2img(task) + elif task_type == TaskType.IMAGE_TO_IMAGE: + return img2img(task) + elif task_type == TaskType.CANNY: + return canny(task) + elif task_type == TaskType.POSE: + return pose(task) + elif task_type == TaskType.TILE_UPSCALE: + return tile_upscale(task) + else: + raise Exception("Invalid task type") + except Exception as e: + print(f"Error: {e}") + slack.error_alert(task, e) + return None diff --git a/inference2.py b/inference2.py new file mode 100644 index 0000000000000000000000000000000000000000..be81471b737a41e8b4d84f3566eae12dd19bc30a --- /dev/null +++ b/inference2.py @@ -0,0 +1,169 @@ +from io import BytesIO + +import torch + +from internals.data.dataAccessor import update_db +from internals.data.task import ModelType, Task, TaskType +from internals.pipelines.inpainter import InPainter +from internals.pipelines.object_remove import ObjectRemoval +from internals.pipelines.prompt_modifier import PromptModifier +from internals.pipelines.remove_background import RemoveBackground +from internals.pipelines.safety_checker import SafetyChecker +from internals.pipelines.upscaler import Upscaler +from internals.util.avatar import Avatar +from internals.util.cache import clear_cuda +from internals.util.commons import (construct_default_s3_url, upload_image, + upload_images) +from internals.util.config import set_configs_from_task, set_root_dir +from internals.util.failure_hander import FailureHandler +from internals.util.slack import Slack + +torch.backends.cudnn.benchmark = True +torch.backends.cuda.matmul.allow_tf32 = True + +num_return_sequences = 4 +auto_mode = False + +slack = Slack() + +prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences) +upscaler = Upscaler() +inpainter = InPainter() +safety_checker = SafetyChecker() +object_removal = ObjectRemoval() +avatar = Avatar() + + +@update_db +@slack.auto_send_alert +def remove_bg(task: Task): + remove_background = RemoveBackground() + output_image = remove_background.remove(task.get_imageUrl()) + + output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId()) + upload_image(output_image, output_key) + + return {"generated_image_url": construct_default_s3_url(output_key)} + + +@update_db +@slack.auto_send_alert +def inpaint(task: Task): + prompt = avatar.add_code_names(task.get_prompt()) + if task.is_prompt_engineering(): + prompt = prompt_modifier.modify(prompt) + else: + prompt = [prompt] * num_return_sequences + + print({"prompts": prompt}) + + images = inpainter.process( + prompt=prompt, + image_url=task.get_imageUrl(), + mask_image_url=task.get_maskImageUrl(), + width=task.get_width(), + height=task.get_height(), + seed=task.get_seed(), + negative_prompt=[task.get_negative_prompt()] * num_return_sequences, + ) + generated_image_urls = upload_images(images, "_inpaint", task.get_taskId()) + + clear_cuda() + + return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls} + + +@update_db +@slack.auto_send_alert +def remove_object(task: Task): + output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId()) + + images = object_removal.process( + image_url=task.get_imageUrl(), + mask_image_url=task.get_maskImageUrl(), + seed=task.get_seed(), + width=task.get_width(), + height=task.get_height(), + ) + generated_image_urls = upload_image(images[0], output_key) + + clear_cuda() + + return {"generated_image_urls": generated_image_urls} + + +@update_db +@slack.auto_send_alert +def upscale_image(task: Task): + output_key = "crecoAI/{}_upscale.png".format(task.get_taskId()) + out_img = None + if task.get_modelType() == ModelType.ANIME: + print("Using Anime model") + out_img = upscaler.upscale_anime( + image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension() + ) + else: + print("Using Real model") + out_img = upscaler.upscale( + image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension() + ) + + upload_image(BytesIO(out_img), output_key) + return {"generated_image_url": construct_default_s3_url(output_key)} + + +def model_fn(model_dir): + print("Logs: model loaded .... starts") + + set_root_dir(__file__) + + FailureHandler.register() + + avatar.load_local() + + prompt_modifier.load() + safety_checker.load() + + object_removal.load(model_dir) + upscaler.load() + inpainter.load() + + safety_checker.apply(inpainter) + + print("Logs: model loaded ....") + return + + +@FailureHandler.clear +def predict_fn(data, pipe): + task = Task(data) + print("task is ", data) + + FailureHandler.handle(task) + + # Set set_environment + set_configs_from_task(task) + + try: + # Set set_environment + set_configs_from_task(task) + + # Fetch avatars + avatar.fetch_from_network(task.get_model_id()) + + task_type = task.get_type() + + if task_type == TaskType.REMOVE_BG: + return remove_bg(task) + elif task_type == TaskType.INPAINT: + return inpaint(task) + elif task_type == TaskType.UPSCALE_IMAGE: + return upscale_image(task) + elif task_type == TaskType.OBJECT_REMOVAL: + return remove_object(task) + else: + raise Exception("Invalid task type") + except Exception as e: + print(f"Error: {e}") + slack.error_alert(task, e) + return None diff --git a/internals/__init__.py b/internals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/internals/data/__init__.py b/internals/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/internals/data/dataAccessor.py b/internals/data/dataAccessor.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b62ffd613beb783a3a5bd6d0a49533bda8d56e --- /dev/null +++ b/internals/data/dataAccessor.py @@ -0,0 +1,104 @@ +import traceback +from typing import Dict, List, Optional + +import requests +from pydash import includes + +from internals.data.task import Task +from internals.util.config import api_endpoint, api_headers +from internals.util.slack import Slack + + +def updateSource(sourceId, userId, state): + print("update source is called") + url = api_endpoint() + f"/comic-crecoai/source/{sourceId}" + headers = { + "Content-Type": "application/json", + "user-id": str(userId), + **api_headers(), + } + + data = {"state": state} + + try: + response = requests.patch(url, headers=headers, json=data, timeout=10) + print("update source response", response) + except requests.exceptions.Timeout: + print("Request timed out while updating source") + except requests.exceptions.RequestException as e: + print(f"Error while updating source: {e}") + + return + + +def saveGeneratedImages(sourceId, userId, has_nsfw: bool): + print("save generation called") + url = api_endpoint() + "/comic-crecoai/source/" + str(sourceId) + "/generatedImages" + headers = { + "Content-Type": "application/json", + "user-id": str(userId), + **api_headers(), + } + data = {"state": "ACTIVE", "has_nsfw": has_nsfw} + + try: + requests.patch(url, headers=headers, json=data) + # print("save generation response", response) + except requests.exceptions.Timeout: + print("Request timed out while saving image") + except requests.exceptions.RequestException as e: + print("Failed to mark source as active: ", e) + return + return + + +def getStyles() -> Optional[Dict]: + url = api_endpoint() + "/comic-crecoai/style" + try: + response = requests.get( + url, + timeout=10, + headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()}, + ) + return response.json() + except requests.exceptions.Timeout: + print("Request timed out while fetching styles") + except requests.exceptions.RequestException as e: + print(f"Error while fetching styles: {e}") + return None + + +def getCharacters(model_id: str) -> Optional[List]: + url = api_endpoint() + "/comic-crecoai/model/{}".format(model_id) + try: + response = requests.get(url, timeout=10, headers=api_headers()) + response = response.json() + response = response["data"]["characters"] + return response + except requests.exceptions.Timeout: + print("Request timed out while fetching characters") + except Exception as e: + print(f"Error while fetching characters: {e}") + return None + + +def update_db(func): + def caller(*args, **kwargs): + if type(args[0]) is not Task: + raise Exception("First argument must be a Task object") + task = args[0] + try: + updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS") + rargs = func(*args, **kwargs) + has_nsfw = rargs.get("has_nsfw", False) + updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED") + saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw) + return rargs + except Exception as e: + print("Error processing image: {}".format(str(e))) + traceback.print_exc() + slack = Slack() + slack.error_alert(task, e) + updateSource(task.get_sourceId(), task.get_userId(), "FAILED") + + return caller diff --git a/internals/data/result.py b/internals/data/result.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d590c4f0e7be8a8e8f6f9a711558ac92e12e56 --- /dev/null +++ b/internals/data/result.py @@ -0,0 +1,19 @@ +from internals.util.config import get_nsfw_access + + +class Result: + images, nsfw = None, None + + def __init__(self, images, nsfw): + self.images = images + self.nsfw = nsfw + + @staticmethod + def from_result(result): + has_nsfw = result.nsfw_content_detected + if has_nsfw and isinstance(has_nsfw, list): + has_nsfw = any(has_nsfw) + + has_nsfw = ~get_nsfw_access() and has_nsfw + return (result.images, bool(has_nsfw)) + # return Result(result.images, result.has_nsfw_concepts) diff --git a/internals/data/task.py b/internals/data/task.py new file mode 100644 index 0000000000000000000000000000000000000000..89279b7ab60d8a7f3495d3729f22bbcc065f852c --- /dev/null +++ b/internals/data/task.py @@ -0,0 +1,125 @@ +from enum import Enum +from typing import Union + +import numpy as np + + +class TaskType(Enum): + TEXT_TO_IMAGE = "GENERATE_AI_IMAGE" + IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE" + POSE = "POSE" + CANNY = "CANNY" + REMOVE_BG = "REMOVE_BG" + INPAINT = "INPAINT" + UPSCALE_IMAGE = "UPSCALE_IMAGE" + TILE_UPSCALE = "TILE_UPSCALE" + OBJECT_REMOVAL = "OBJECT_REMOVAL" + + +class ModelType(Enum): + REAL = 10000 + ANIME = 10001 + COMIC = 10002 + + +class Task: + def __init__(self, data): + self.__data = data + if data.get("seed", -1) == None or self.get_seed() == -1: + self.__data["seed"] = np.random.randint(0, np.iinfo(np.int64).max) + prompt = data.get("prompt", "") + if prompt is None: + self.__data["prompt"] = "" + else: + self.__data["prompt"] = data.get("prompt", "")[:200] + + def get_taskId(self) -> str: + return self.__data.get("task_id") + + def get_sourceId(self) -> str: + return self.__data.get("source_id") + + def get_imageUrl(self) -> str: + return self.__data.get("imageUrl", None) + + def get_prompt(self) -> str: + return self.__data.get("prompt", "") + + def get_userId(self) -> str: + return self.__data.get("userId", "") + + def get_email(self) -> str: + return self.__data.get("email", "") + + def get_style(self) -> str: + return self.__data.get("style", None) + + def get_iteration(self) -> float: + return float(self.__data.get("iteration", 3.0)) + + def get_modelType(self) -> ModelType: + id = self.get_model_id() + return ModelType(id) + + def get_model_id(self) -> int: + return int(self.__data.get("modelId", 10000)) + + def get_width(self) -> int: + return int(self.__data.get("width", 512)) + + def get_height(self) -> int: + return int(self.__data.get("height", 512)) + + def get_seed(self) -> int: + return int(self.__data.get("seed", -1)) + + def get_steps(self) -> int: + return int(self.__data.get("steps", "75")) + + def get_type(self) -> Union[TaskType, None]: + try: + return TaskType(self.__data.get("task_type")) + except ValueError: + return None + + def get_maskImageUrl(self) -> str: + return self.__data.get("maskImageUrl") + + def get_negative_prompt(self) -> str: + return self.__data.get("negative_prompt", "") + + def is_prompt_engineering(self) -> bool: + return self.__data.get("auto_mode", True) + + def get_queue_name(self) -> str: + return self.__data.get("queue_name", "") + + def get_resize_dimension(self) -> int: + return self.__data.get("resize_dimension", 1024) + + def get_ti_guidance_scale(self) -> float: + return self.__data.get("ti_guidance_scale", 7.5) + + def get_i2i_guidance_scale(self) -> float: + return self.__data.get("i2i_guidance_scale", 7.5) + + def get_i2i_strength(self) -> float: + return self.__data.get("i2i_strength", 0.75) + + def get_cy_guidance_scale(self) -> float: + return self.__data.get("cy_guidance_scale", 9) + + def get_po_guidance_scale(self) -> float: + return self.__data.get("po_guidance_scale", 7.5) + + def get_nsfw_threshold(self) -> float: + return self.__data.get("nsfw_threshold", 0.03) + + def can_access_nsfw(self) -> bool: + return self.__data.get("can_access_nsfw", False) + + def get_access_token(self) -> str: + return self.__data.get("access_token", "") + + def get_raw(self) -> dict: + return self.__data.copy() diff --git a/internals/pipelines/commons.py b/internals/pipelines/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..47fb890f801fd101067ac1d25bbab0e019b11a5f --- /dev/null +++ b/internals/pipelines/commons.py @@ -0,0 +1,119 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from diffusers import StableDiffusionImg2ImgPipeline + +from internals.data.result import Result +from internals.pipelines.twoStepPipeline import two_step_pipeline +from internals.util.commons import disable_safety_checker, download_image + + +class AbstractPipeline: + def load(self, model_dir: str): + pass + + def create(self, pipe): + pass + + +class Text2Img(AbstractPipeline): + def load(self, model_dir: str): + self.pipe = two_step_pipeline.from_pretrained( + model_dir, torch_dtype=torch.float16 + ).to("cuda") + self.__patch() + + def create(self, pipeline: AbstractPipeline): + self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda") + self.__patch() + + def __patch(self): + self.pipe.enable_xformers_memory_efficient_attention() + + @torch.inference_mode() + def process( + self, + prompt: Union[str, List[str]] = None, + modified_prompts: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + iteration: float = 3.0, + ): + result = self.pipe.two_step_pipeline( + prompt=prompt, + modified_prompts=modified_prompts, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + eta=eta, + generator=generator, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + output_type=output_type, + return_dict=return_dict, + callback=callback, + callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, + iteration=iteration, + ) + return Result.from_result(result) + + +class Img2Img(AbstractPipeline): + def load(self, model_dir: str): + self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained( + model_dir, torch_dtype=torch.float16 + ).to("cuda") + self.__patch() + + def create(self, pipeline: AbstractPipeline): + self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to( + "cuda" + ) + self.__patch() + + def __patch(self): + self.pipe.enable_xformers_memory_efficient_attention() + + @torch.inference_mode() + def process( + self, + prompt: List[str], + imageUrl: str, + negative_prompt: List[str], + strength: float, + guidance_scale: float, + steps: int, + width: int, + height: int, + ): + image = download_image(imageUrl).resize((width, height)) + + result = self.pipe.__call__( + prompt=prompt, + image=image, + strength=strength, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + num_inference_steps=steps, + ) + return Result.from_result(result) diff --git a/internals/pipelines/controlnets.py b/internals/pipelines/controlnets.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a2982f13e579cf929791c04c11d218c91a9793 --- /dev/null +++ b/internals/pipelines/controlnets.py @@ -0,0 +1,221 @@ +from typing import List + +import cv2 +import numpy as np +import torch +from controlnet_aux import OpenposeDetector +from diffusers import ( + ControlNetModel, + DiffusionPipeline, + StableDiffusionControlNetPipeline, + UniPCMultistepScheduler, +) +from PIL import Image +from tqdm import gui + +from internals.data.result import Result +from internals.pipelines.commons import AbstractPipeline +from internals.util.cache import clear_cuda_and_gc +from internals.util.commons import download_image + + +class ControlNet(AbstractPipeline): + __current_task_name = "" + + def load(self, model_dir: str): + # we will load canny by default + self.load_canny() + + # controlnet pipeline for canny and pose + pipe = DiffusionPipeline.from_pretrained( + model_dir, + controlnet=self.controlnet, + torch_dtype=torch.float16, + custom_pipeline="stable_diffusion_controlnet_img2img", + ).to("cuda") + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.enable_xformers_memory_efficient_attention() + self.pipe = pipe + + # controlnet pipeline for tile upscaler + pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda") + pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config) + pipe2.enable_xformers_memory_efficient_attention() + self.pipe2 = pipe2 + + def load_canny(self): + if self.__current_task_name == "canny": + return + canny = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16 + ).to("cuda") + self.__current_task_name = "canny" + self.controlnet = canny + if hasattr(self, "pipe"): + self.pipe.controlnet = canny + if hasattr(self, "pipe2"): + self.pipe2.controlnet = canny + clear_cuda_and_gc() + + def load_pose(self): + if self.__current_task_name == "pose": + return + pose = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16 + ).to("cuda") + self.__current_task_name = "pose" + self.controlnet = pose + if hasattr(self, "pipe"): + self.pipe.controlnet = pose + if hasattr(self, "pipe2"): + self.pipe2.controlnet = pose + clear_cuda_and_gc() + + def load_tile_upscaler(self): + if self.__current_task_name == "tile_upscaler": + return + tile_upscaler = ControlNetModel.from_pretrained( + "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16 + ).to("cuda") + self.__current_task_name = "tile_upscaler" + self.controlnet = tile_upscaler + if hasattr(self, "pipe"): + self.pipe.controlnet = tile_upscaler + if hasattr(self, "pipe2"): + self.pipe2.controlnet = tile_upscaler + clear_cuda_and_gc() + + def cleanup(self): + self.pipe.controlnet = None + self.pipe2.controlnet = None + self.controlnet = None + self.__current_task_name = "" + + clear_cuda_and_gc() + + @torch.inference_mode() + def process_canny( + self, + prompt: List[str], + imageUrl: str, + seed: int, + steps: int, + negative_prompt: List[str], + guidance_scale: float, + height: int, + width: int, + ): + if self.__current_task_name != "canny": + raise Exception("ControlNet is not loaded with canny model") + + torch.manual_seed(seed) + + init_image = download_image(imageUrl).resize((width, height)) + init_image = self.__canny_detect_edge(init_image) + + result = self.pipe2.__call__( + prompt=prompt, + image=init_image, + guidance_scale=guidance_scale, + num_images_per_prompt=1, + negative_prompt=negative_prompt, + num_inference_steps=steps, + height=height, + width=width, + ) + return Result.from_result(result) + + @torch.inference_mode() + def process_pose( + self, + prompt: List[str], + image: List[Image.Image], + seed: int, + steps: int, + guidance_scale: float, + negative_prompt: List[str], + height: int, + width: int, + ): + if self.__current_task_name != "pose": + raise Exception("ControlNet is not loaded with pose model") + + torch.manual_seed(seed) + + result = self.pipe2.__call__( + prompt=prompt, + image=image, + num_images_per_prompt=1, + num_inference_steps=steps, + negative_prompt=negative_prompt, + guidance_scale=guidance_scale, + height=height, + width=width, + ) + return Result.from_result(result) + + @torch.inference_mode() + def process_tile_upscaler( + self, + imageUrl: str, + prompt: str, + negative_prompt: str, + steps: int, + seed: int, + height: int, + width: int, + resize_dimension: int, + guidance_scale: float, + ): + if self.__current_task_name != "tile_upscaler": + raise Exception("ControlNet is not loaded with tile_upscaler model") + + torch.manual_seed(seed) + + init_image = download_image(imageUrl).resize((width, height)) + condition_image = self.__resize_for_condition_image( + init_image, resize_dimension + ) + + result = self.pipe.__call__( + image=condition_image, + prompt=prompt, + controlnet_conditioning_image=condition_image, + num_inference_steps=steps, + negative_prompt=negative_prompt, + height=condition_image.size[1], + width=condition_image.size[0], + strength=1.0, + guidance_scale=guidance_scale, + ) + return Result.from_result(result) + + def detect_pose(self, imageUrl: str) -> Image.Image: + detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") + image = download_image(imageUrl) + image = detector.__call__(image, hand_and_face=True) + return image + + def __canny_detect_edge(self, image: Image.Image) -> Image.Image: + image_array = np.array(image) + + low_threshold = 100 + high_threshold = 200 + + image_array = cv2.Canny(image_array, low_threshold, high_threshold) + image_array = image_array[:, :, None] + image_array = np.concatenate([image_array, image_array, image_array], axis=2) + canny_image = Image.fromarray(image_array) + return canny_image + + def __resize_for_condition_image(self, image: Image.Image, resolution: int): + input_image = image.convert("RGB") + W, H = input_image.size + k = float(resolution) / min(W, H) + H *= k + W *= k + H = int(round(H / 64.0)) * 64 + W = int(round(W / 64.0)) * 64 + img = input_image.resize((W, H), resample=Image.LANCZOS) + return img diff --git a/internals/pipelines/img_classifier.py b/internals/pipelines/img_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..aca47585b9762c450aac0e5554fa57aecfe1c1ea --- /dev/null +++ b/internals/pipelines/img_classifier.py @@ -0,0 +1,24 @@ +from typing import List + +from transformers import pipeline + +from internals.util.commons import download_image + + +class ImageClassifier: + def __init__(self, candidates: List[str] = ["realistic", "anime", "comic"]): + self.__candidates = candidates + + def load(self): + self.pipe = pipeline( + "zero-shot-image-classification", + model="philschmid/clip-zero-shot-image-classification", + ) + + def classify(self, image_url: str, width: int, height: int) -> str: + image = download_image(image_url).resize((width, height)) + results = self.pipe.__call__([image], candidate_labels=self.__candidates) + results = results[0] + if len(results) > 0: + return results[0]["label"] + return "" diff --git a/internals/pipelines/img_to_text.py b/internals/pipelines/img_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8354eb16e5db2fa6d4a09746c2f2e13507b81e --- /dev/null +++ b/internals/pipelines/img_to_text.py @@ -0,0 +1,31 @@ +import re + +import torch +from torchvision import transforms +from transformers import BlipForConditionalGeneration, BlipProcessor + +from internals.util.commons import download_image + + +class Image2Text: + def load(self): + self.processor = BlipProcessor.from_pretrained( + "Salesforce/blip-image-captioning-large" + ) + self.model = BlipForConditionalGeneration.from_pretrained( + "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16 + ).to("cuda") + + def process(self, imageUrl: str) -> str: + image = download_image(imageUrl).resize((512, 512)) + inputs = self.processor.__call__(image, return_tensors="pt").to( + "cuda", torch.float16 + ) + output_ids = self.model.generate( + **inputs, do_sample=False, top_p=0.9, max_length=128 + ) + output_text = self.processor.batch_decode(output_ids) + print(output_text) + output_text = output_text[0] + output_text = re.sub("|\\n|\[SEP\]", "", output_text) + return output_text diff --git a/internals/pipelines/inpainter.py b/internals/pipelines/inpainter.py new file mode 100644 index 0000000000000000000000000000000000000000..04fb2abb3e3371811f7415b24fff745e8c387fb7 --- /dev/null +++ b/internals/pipelines/inpainter.py @@ -0,0 +1,41 @@ +from typing import List, Union + +import torch +from diffusers import StableDiffusionInpaintPipeline + +from internals.pipelines.commons import AbstractPipeline +from internals.util.commons import disable_safety_checker, download_image + + +class InPainter(AbstractPipeline): + def load(self): + self.pipe = StableDiffusionInpaintPipeline.from_pretrained( + "jayparmr/icbinp_v8_inpaint_v2", + torch_dtype=torch.float16, + ).to("cuda") + disable_safety_checker(self.pipe) + + @torch.inference_mode() + def process( + self, + image_url: str, + mask_image_url: str, + width: int, + height: int, + seed: int, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]], + ): + torch.manual_seed(seed) + + input_img = download_image(image_url).resize((width, height)) + mask_img = download_image(mask_image_url).resize((width, height)) + + return self.pipe.__call__( + prompt=prompt, + image=input_img, + mask_image=mask_img, + height=height, + width=width, + negative_prompt=negative_prompt, + ).images diff --git a/internals/pipelines/object_remove.py b/internals/pipelines/object_remove.py new file mode 100644 index 0000000000000000000000000000000000000000..0003a1bca1ef6bd153dc68ea84a0b8e32df12d22 --- /dev/null +++ b/internals/pipelines/object_remove.py @@ -0,0 +1,82 @@ +import os +from pathlib import Path +from typing import List + +import cv2 +import numpy as np +import torch +import tqdm +from omegaconf import OmegaConf +from PIL import Image +from torch.utils.data._utils.collate import default_collate + +from internals.util.commons import download_file, download_image +from internals.util.config import get_root_dir +from saicinpainting.evaluation.utils import move_to_device +from saicinpainting.training.data.datasets import make_default_val_dataset +from saicinpainting.training.trainers import load_checkpoint + + +class ObjectRemoval: + def load(self, model_dir): + print("Downloading LAMA model...") + + self.lama_path = Path.home() / ".cache" / "lama" + + out_file = self.lama_path / "models" / "best.ckpt" + os.makedirs(os.path.dirname(out_file), exist_ok=True) + download_file( + "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file + ) + config = OmegaConf.load(get_root_dir() + "/config.yml") + config.training_model.predict_only = True + self.model = load_checkpoint( + config, str(out_file), strict=False, map_location="cuda" + ) + self.model.freeze() + self.model.to("cuda") + + @torch.no_grad() + def process( + self, + image_url: str, + mask_image_url: str, + seed: int, + width: int, + height: int, + ) -> List: + torch.manual_seed(seed) + + img_folder = self.lama_path / "images" + indir = img_folder / "input" + + img_folder.mkdir(parents=True, exist_ok=True) + indir.mkdir(parents=True, exist_ok=True) + + download_image(image_url).resize((width, height)).save(indir / "data.png") + download_image(mask_image_url).resize((width, height)).save( + indir / "data_mask.png" + ) + + dataset = make_default_val_dataset( + img_folder / "input", img_suffix=".png", pad_out_to_modulo=8 + ) + + out_images = [] + for img_i in tqdm.trange(len(dataset)): + batch = move_to_device(default_collate([dataset[img_i]]), "cuda") + batch["mask"] = (batch["mask"] > 0) * 1 + batch = self.model(batch) + out_path = str(img_folder / "out.png") + + cur_res = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy() + + cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_path, cur_res) + + image = Image.open(out_path).convert("RGB") + out_images.append(image) + os.remove(out_path) + + return out_images diff --git a/internals/pipelines/prompt_modifier.py b/internals/pipelines/prompt_modifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2093fedcbb45eece0cd6adada3273eaa12e7034e --- /dev/null +++ b/internals/pipelines/prompt_modifier.py @@ -0,0 +1,54 @@ +from typing import List, Optional + +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + + +class PromptModifier: + def __init__(self, num_of_sequences: Optional[int] = 4): + self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""} + self.__num_of_sequences = num_of_sequences + + def load(self): + self.prompter_model = AutoModelForCausalLM.from_pretrained( + "Gustavosta/MagicPrompt-Stable-Diffusion" + ) + self.prompter_tokenizer = AutoTokenizer.from_pretrained( + "Gustavosta/MagicPrompt-Stable-Diffusion" + ) + self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token + self.prompter_tokenizer.padding_side = "left" + + def modify(self, text: str) -> List[str]: + eos_id = self.prompter_tokenizer.eos_token_id + # restricted_words_list = ["octane", "cyber"] + # restricted_words_token_ids = prompter_tokenizer( + # restricted_words_list, add_special_tokens=False + # ).input_ids + + generation_config = GenerationConfig( + do_sample=False, + max_new_tokens=75, + num_beams=4, + num_return_sequences=self.__num_of_sequences, + eos_token_id=eos_id, + pad_token_id=eos_id, + length_penalty=-1.0, + ) + + input_ids = self.prompter_tokenizer(text.strip(), return_tensors="pt").input_ids + outputs = self.prompter_model.generate( + input_ids, generation_config=generation_config + ) + output_texts = self.prompter_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + output_texts = self.__patch_blacklist_words(output_texts) + return output_texts + + def __patch_blacklist_words(self, texts: List[str]): + def replace_all(text, dic): + for i, j in dic.items(): + text = text.replace(i, j) + return text + + return [replace_all(text, self.__blacklist) for text in texts] diff --git a/internals/pipelines/remove_background.py b/internals/pipelines/remove_background.py new file mode 100644 index 0000000000000000000000000000000000000000..be65bf674c6a5887202d505f594bf68428748485 --- /dev/null +++ b/internals/pipelines/remove_background.py @@ -0,0 +1,16 @@ +import io +from typing import Union + +from PIL import Image +from rembg import remove + +from internals.util.commons import read_url + + +class RemoveBackground: + def remove(self, image: Union[str, Image.Image]) -> Image.Image: + if type(image) is str: + image = Image.open(io.BytesIO(read_url(image))) + + output = remove(image) + return output diff --git a/internals/pipelines/safety_checker.py b/internals/pipelines/safety_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..27e069f83f4205c743eb99ea3cf01efd506f99ce --- /dev/null +++ b/internals/pipelines/safety_checker.py @@ -0,0 +1,163 @@ +from re import L + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel + +from internals.pipelines.commons import AbstractPipeline +from internals.util.config import get_nsfw_access, get_nsfw_threshold + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) + + +class SafetyChecker: + def load(self): + self.model = StableDiffusionSafetyCheckerV2.from_pretrained( + "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16 + ).to("cuda") + + def apply(self, pipeline: AbstractPipeline): + if hasattr(pipeline, "pipe"): + pipeline.pipe.safety_checker = self.model + if hasattr(pipeline, "pipe2"): + pipeline.pipe2.safety_checker = self.model + + +class StableDiffusionSafetyCheckerV2(PreTrainedModel): + config_class = CLIPConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear( + config.vision_config.hidden_size, config.projection_dim, bias=False + ) + + self.concept_embeds = nn.Parameter( + torch.ones(17, config.projection_dim), requires_grad=False + ) + self.special_care_embeds = nn.Parameter( + torch.ones(3, config.projection_dim), requires_grad=False + ) + + self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) + self.special_care_embeds_weights = nn.Parameter( + torch.ones(3), requires_grad=False + ) + + @torch.no_grad() + def forward(self, clip_input, images): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + special_cos_dist = ( + cosine_distance(image_embeds, self.special_care_embeds) + .cpu() + .float() + .numpy() + ) + cos_dist = ( + cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() + ) + + result = [] + batch_size = image_embeds.shape[0] + for i in range(batch_size): + result_img = { + "special_scores": {}, + "special_care": [], + "concept_scores": {}, + "bad_concepts": [], + } + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round( + concept_cos - concept_threshold + adjustment, 3 + ) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append( + {concept_idx, result_img["special_scores"][concept_idx]} + ) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round( + concept_cos - concept_threshold + adjustment, 3 + ) + if result_img["concept_scores"][concept_idx] > get_nsfw_threshold(): + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + # Blur images based on NSFW score + # ------------------------------- + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if any(has_nsfw_concepts) and not get_nsfw_access(): + if torch.is_tensor(images) or torch.is_tensor(images[0]): + image = images[idx].cpu().numpy().astype(np.float32) + image = cv2.blur(image, (30, 30)) + image = torch.from_numpy(image) + images[idx] = image + else: + images[idx] = cv2.blur(images[idx], (30, 30)) + + if any(has_nsfw_concepts): + print("NSFW") + + return images, has_nsfw_concepts + + @torch.no_grad() + def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = cosine_distance(image_embeds, self.concept_embeds) + + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign images + adjustment = 0.0 + + special_scores = ( + special_cos_dist - self.special_care_embeds_weights + adjustment + ) + # special_scores = special_scores.round(decimals=3) + special_care = torch.any(special_scores > 0, dim=1) + special_adjustment = special_care * 0.01 + special_adjustment = special_adjustment.unsqueeze(1).expand( + -1, cos_dist.shape[1] + ) + + concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment + # concept_scores = concept_scores.round(decimals=3) + has_nsfw_concepts = torch.any(concept_scores > get_nsfw_threshold(), dim=1) + + # Blur images based on NSFW score + # ------------------------------- + if not get_nsfw_access(): + image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32) + image = cv2.blur(image, (30, 30)) + image = torch.from_numpy(image) + images[has_nsfw_concepts] = image + + return images, has_nsfw_concepts diff --git a/internals/pipelines/twoStepPipeline.py b/internals/pipelines/twoStepPipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4822ff9ddff6890f47bc1a1f81bc3ac28a3392 --- /dev/null +++ b/internals/pipelines/twoStepPipeline.py @@ -0,0 +1,252 @@ +import torch +from diffusers import StableDiffusionPipeline + +torch.backends.cudnn.benchmark = True +torch.backends.cuda.matmul.allow_tf32 = True + +from typing import Any, Callable, Dict, List, Optional, Union + +from diffusers import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + + +class two_step_pipeline(StableDiffusionPipeline): + @torch.no_grad() + def two_step_pipeline( + self, + prompt: Union[str, List[str]] = None, + modified_prompts: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + iteration: float = 3.0, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Examples: + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + modified_embeds = self._encode_prompt( + modified_prompts, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + print("mod prompt size : ", modified_embeds.size(), modified_embeds.dtype) + + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + print("prompt size : ", prompt_embeds.size(), prompt_embeds.dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if i == int(len(timesteps) / iteration): + print("modified prompts") + prompt_embeds = modified_embeds + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/internals/pipelines/upscaler.py b/internals/pipelines/upscaler.py new file mode 100644 index 0000000000000000000000000000000000000000..462e7985c4f76b83ebedd067928fd568f0b12496 --- /dev/null +++ b/internals/pipelines/upscaler.py @@ -0,0 +1,91 @@ +import math +import os +from pathlib import Path +from typing import Union + +import cv2 +import numpy as np +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.utils.download_util import load_file_from_url +from PIL import Image +from realesrgan import RealESRGANer + +import internals.util.image as ImageUtil +from internals.util.commons import download_image + + +class Upscaler: + __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" + __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth" + + def load(self): + download_dir = Path(Path.home() / ".cache" / "realesrgan") + download_dir.mkdir(parents=True, exist_ok=True) + + self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir) + self.__model_path_anime = self.__preload_model( + self.__model_esrgan_anime_url, download_dir + ) + + def upscale(self, image: Union[str, Image.Image], resize_dimension: int) -> bytes: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + return self.__internal_upscale( + image, resize_dimension, self.__model_path, model + ) + + def upscale_anime( + self, image: Union[str, Image.Image], resize_dimension: int + ) -> bytes: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + return self.__internal_upscale( + image, resize_dimension, self.__model_path_anime, model + ) + + def __preload_model(self, url: str, download_dir: Path): + name = url.split("/")[-1] + if not os.path.exists(str(download_dir / name)): + return load_file_from_url( + url=url, + model_dir=str(download_dir), + progress=True, + file_name=None, + ) + else: + return str(download_dir / name) + + def __internal_upscale( + self, + image, + resize_dimension: int, + model_path: str, + rrbdnet: RRDBNet, + ) -> bytes: + if type(image) is str: + image = download_image(image) + image = ImageUtil.resize_image_to512(image) + image = ImageUtil.to_bytes(image) + + upsampler = RealESRGANer( + scale=4, model_path=model_path, model=rrbdnet, half="fp16", gpu_id="0" + ) + image_array = np.frombuffer(image, dtype=np.uint8) + input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) + dimension = min(input_image.shape[0], input_image.shape[1]) + scale = max(math.floor(resize_dimension / dimension), 2) + output, _ = upsampler.enhance(input_image, outscale=scale) + out_bytes = cv2.imencode(".png", output)[1].tobytes() + return out_bytes diff --git a/internals/util/__init__.py b/internals/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/internals/util/args.py b/internals/util/args.py new file mode 100644 index 0000000000000000000000000000000000000000..eae0dd06ce2bb3e7057a4981338448c50d8c75a3 --- /dev/null +++ b/internals/util/args.py @@ -0,0 +1,13 @@ +import re +from typing import Dict + + +def apply_style_args(data: Dict): + prompt = data.get("prompt", None) + if prompt is None: + return + result = re.match(r"\[style:(.*?)\]", prompt) + if result is not None: + style = result.group(1) + data["style"] = style + data["prompt"] = prompt.replace(f"[style:{style}]", "").strip() diff --git a/internals/util/avatar.py b/internals/util/avatar.py new file mode 100644 index 0000000000000000000000000000000000000000..94bef951e1cd6927c2c34eebe823e895846e671e --- /dev/null +++ b/internals/util/avatar.py @@ -0,0 +1,59 @@ +import json +import os +import re + +from internals.data.dataAccessor import getCharacters +from internals.util.config import root_dir + + +class Avatar: + __avatars = {} + + def load_local(self): + self.__find_available_characters(root_dir) + if len(self.__avatars.items()) > 0: + print("Local characters", self.__avatars) + + def fetch_from_network(self, model_id: int): + characters = getCharacters(str(model_id)) + if characters is not None: + for character in characters: + item = { + "avatarName": str(character["title"]).lower(), + "codename": character["tag"], + "extraPrompt": character["extraData"]["extraPrompt"], + } + self.__avatars[item["avatarName"]] = item + + def add_code_names(self, prompt): + array_of_objects = self.__avatars.values() + + for obj in array_of_objects: + prompt = ( + re.sub( + r"\b" + obj["avatarName"] + r"\b", + obj["extraPrompt"], + prompt, + flags=re.IGNORECASE, + ) + + " " + ) + print(prompt) + return prompt + + def __find_available_characters(self, path: str): + if os.path.exists(path + "/characters.json"): + print(path) + try: + print("Loading characters") + with open(path + "/characters.json") as f: + data = json.load(f) + print("Characters: ", data) + if "avatarName" in data[0]: + for item in data: + self.__avatars[item["avatarName"]] = item + print("Avatars", self.__avatars) + else: + print("Invalid characters.json file") + except Exception as e: + print("Error Loading characters", e) diff --git a/internals/util/cache.py b/internals/util/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..d80e5fe0090c505b3609c8866597e8c09878fcae --- /dev/null +++ b/internals/util/cache.py @@ -0,0 +1,31 @@ +import gc + +import torch + + +def clear_cuda_and_gc(): + clear_cuda() + clear_gc() + + +def clear_cuda(): + torch.cuda.empty_cache() + + +def clear_gc(): + gc.collect() + + +def auto_clear_cuda_and_gc(controlnet): + def auto_clear_cuda_and_gc_wrapper(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + controlnet.cleanup() + clear_cuda_and_gc() + raise e + + return wrapper + + return auto_clear_cuda_and_gc_wrapper diff --git a/internals/util/commons.py b/internals/util/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..355ede3467d721d120c662f1103856ac981978e6 --- /dev/null +++ b/internals/util/commons.py @@ -0,0 +1,203 @@ +import json +import os +import pprint +import random +import re +from io import BytesIO +from pathlib import Path +from typing import Union + +import boto3 +import requests + +from internals.util.config import api_endpoint, api_headers + +s3 = boto3.client("s3") +import io +import urllib.request + +from PIL import Image + +black_list = {"alphonse mucha": "", "adolphe bouguereau": ""} +pp = pprint.PrettyPrinter(indent=4) + +webhook_url = ( + "https://hooks.slack.com/services/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" +) +error_webhook = ( + "https://hooks.slack.com/services/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" +) + +characterSheets = [ + "character+sheets/1.1.png", + "character+sheets/10.1.png", + "character+sheets/11.1.png", + "character+sheets/12.1.png", + "character+sheets/13.1.png", + "character+sheets/14.1.png", + "character+sheets/16.1.png", + "character+sheets/17.1.png", + "character+sheets/18.1.png", + "character+sheets/19.1.png", + "character+sheets/2.1.png", + "character+sheets/20.1.png", + "character+sheets/21.1.png", + "character+sheets/22.1.png", + "character+sheets/23.1.png", + "character+sheets/24.1.png", + "character+sheets/25.1.png", + "character+sheets/26.1.png", + "character+sheets/27.1.png", + "character+sheets/28.1.png", + "character+sheets/29.1.png", + "character+sheets/3.1.png", + "character+sheets/30.1.png", + "character+sheets/31.1.png", + "character+sheets/32.1.png", + "character+sheets/33.1.png", + "character+sheets/34.1.png", + "character+sheets/35.1.png", + "character+sheets/36.1.png", + "character+sheets/38.1.png", + "character+sheets/39.1.png", + "character+sheets/4.1.png", + "character+sheets/40.1.png", + "character+sheets/42.1.png", + "character+sheets/43.1.png", + "character+sheets/44.1.png", + "character+sheets/45.1.png", + "character+sheets/46.1.png", + "character+sheets/47.1.png", + "character+sheets/48.1.png", + "character+sheets/49.1.png", + "character+sheets/5.1.png", + "character+sheets/50.1.png", + "character+sheets/51.1.png", + "character+sheets/52.1.png", + "character+sheets/53.1.png", + "character+sheets/54.1.png", + "character+sheets/55.1.png", + "character+sheets/56.1.png", + "character+sheets/57.1.png", + "character+sheets/58.1.png", + "character+sheets/59.1.png", + "character+sheets/60.1.png", + "character+sheets/61.1.png", + "character+sheets/62.1.png", + "character+sheets/63.1.png", + "character+sheets/64.1.png", + "character+sheets/65.1.png", + "character+sheets/66.1.png", + "character+sheets/7.1.png", + "character+sheets/8.1.png", + "character+sheets/9.1.png", +] + + +def upload_images(images, processName: str, taskId: str): + imageUrls = [] + for i, image in enumerate(images): + img_io = BytesIO() + image.save(img_io, "JPEG", quality=100) + img_io.seek(0) + key = "crecoAI/{}{}_{}.png".format(taskId, processName, i) + requests.post( + api_endpoint() + + "/comic-content/v1.0/upload/crecoai-assets-2?fileName=" + + "{}{}_{}.png".format(taskId, processName, i), + headers=api_headers(), + files={"file": ("image.png", img_io, "image/png")}, + ) + # t = s3.put_object( + # Bucket="comic-assets", Key=key, Body=img_io.getvalue(), ACL="public-read" + # ) + # print("uploading done to s3", key, t) + imageUrls.append( + "https://comic-assets.s3.ap-south-1.amazonaws.com/crecoAI/{}{}_{}.png".format( + taskId, processName, i + ) + ) + + print({"promptImages": imageUrls}) + + return imageUrls + + +def upload_image(image: Union[Image.Image, BytesIO], out_path): + if type(image) is Image.Image: + buffer = io.BytesIO() + image.save(buffer, format="PNG") + image = buffer + + image.seek(0) + requests.post( + api_endpoint() + + "/comic-content/v1.0/upload/crecoai-assets-2?fileName=" + + str(out_path).replace("crecoAI/", ""), + headers=api_headers(), + files={"file": ("image.png", image, "image/png")}, + ) + # s3.upload_fileobj(image, "comic-assets", out_path, ExtraArgs={"ACL": "public-read"}) + image.close() + + image_url = "https://comic-assets.s3.ap-south-1.amazonaws.com/" + out_path + print({"promptImages": image_url}) + + return image_url + + +def download_image(url) -> Image.Image: + response = requests.get(url) + return Image.open(BytesIO(response.content)).convert("RGB") + + +def download_file(url, out_path: Path): + with requests.get(url, stream=True) as r: + r.raise_for_status() + with open(out_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + +def pickPoses(): + random_images = random.sample(characterSheets, 4) + poses = [] + prefix = "https://comic-assets.s3.ap-south-1.amazonaws.com/" + + # Use list comprehension to add prefix to all elements in the array + random_images_with_prefix = [prefix + img for img in random_images] + + print(random_images_with_prefix) + for imageUrl in random_images_with_prefix: + # Download and resize the image + init_image = download_image(imageUrl).resize((512, 512)) + + # Open the pose image + imageUrlPose = imageUrl + # print(imageUrl) + input_image_bytes = read_url(imageUrlPose) + # print(input_image_bytes) + pose_image = Image.open(io.BytesIO(input_image_bytes)).convert("RGB") + # print(pose_image) + pose_image = pose_image.resize((512, 512)) + # print(pose_image) + # Append the result to the poses array + poses.append(pose_image) + + return poses + + +def construct_default_s3_url(key): + return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key + + +def read_url(url: str): + with urllib.request.urlopen(url) as u: + return u.read() + + +def disable_safety_checker(pipe): + def dummy(images, **kwargs): + return images, False + + pipe.safety_checker = None diff --git a/internals/util/config.py b/internals/util/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c641b3e213da760e2885f110b06b08d329ec0379 --- /dev/null +++ b/internals/util/config.py @@ -0,0 +1,66 @@ +import os + +from internals.data.task import Task + +env = "gamma" +nsfw_threshold = 0.0 +nsfw_access = False +access_token = "" +root_dir = "" + + +def set_root_dir(main_file: str): + global root_dir + root_dir = os.path.dirname(os.path.abspath(main_file)) + + +def set_configs_from_task(task: Task): + global env, nsfw_threshold, nsfw_access, access_token + name = task.get_queue_name() + if name.startswith("prod"): + env = "prod" + else: + env = "gamma" + nsfw_threshold = task.get_nsfw_threshold() + nsfw_access = task.can_access_nsfw() + access_token = task.get_access_token() + + +def get_root_dir(): + global root_dir + return root_dir + + +def get_environment(): + global env + return env + + +def get_nsfw_threshold(): + global nsfw_threshold + return nsfw_threshold + + +def get_nsfw_access(): + global nsfw_access + return nsfw_access + + +def api_headers(): + return { + "Access-Token": access_token, + } + + +def api_endpoint(): + if env == "prod": + return "https://prod.pratilipicomics.com" + else: + return "https://gamma.pratilipicomics.com" + + +def comic_url(): + if env == "prod": + return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80" + else: + return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80" diff --git a/internals/util/failure_hander.py b/internals/util/failure_hander.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdfa2660675b41b0d32ffd8c93c90562a6590a1 --- /dev/null +++ b/internals/util/failure_hander.py @@ -0,0 +1,40 @@ +import json +import os +from pathlib import Path + +from internals.data.dataAccessor import updateSource +from internals.data.task import Task +from internals.util.config import set_configs_from_task +from internals.util.slack import Slack + + +class FailureHandler: + __task_path = Path.home() / ".cache" / "inference" / "task.json" + + @staticmethod + def register(): + path = FailureHandler.__task_path + path.parent.mkdir(parents=True, exist_ok=True) + if path.exists(): + task = Task(json.loads(path.read_text())) + set_configs_from_task(task) + # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE")) + updateSource(task.get_sourceId(), task.get_userId(), "FAILED") + os.remove(path) + + @staticmethod + def clear(func): + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + if result is not None: + path = FailureHandler.__task_path + if path.exists(): + os.remove(path) + return result + + return wrapper + + @staticmethod + def handle(task: Task): + path = FailureHandler.__task_path + path.write_text(json.dumps(task.get_raw())) diff --git a/internals/util/image.py b/internals/util/image.py new file mode 100644 index 0000000000000000000000000000000000000000..e6412c8342c4eae9a3da3d847557624084df797d --- /dev/null +++ b/internals/util/image.py @@ -0,0 +1,18 @@ +import io + +from PIL import Image + + +def to_bytes(image: Image.Image) -> bytes: + with io.BytesIO() as output: + image.save(output, format="JPEG") + return output.getvalue() + + +def resize_image_to512(image: Image.Image) -> Image.Image: + iw, ih = image.size + if iw > ih: + image = image.resize((512, int(512 * ih / iw))) + else: + image = image.resize((int(512 * iw / ih), 512)) + return image diff --git a/internals/util/lora_style.py b/internals/util/lora_style.py new file mode 100644 index 0000000000000000000000000000000000000000..95eb2c83011aa6c4e42b2ce6fa0d71b3c16f1207 --- /dev/null +++ b/internals/util/lora_style.py @@ -0,0 +1,154 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Union + +import boto3 +import torch +from lora_diffusion import patch_pipe, tune_lora_scale +from pydash import chain + +from internals.data.dataAccessor import getStyles +from internals.util.commons import download_file + + +class LoraStyle: + class LoraPatcher: + def __init__(self, pipe, style: Dict[str, Any]): + self.__style = style + self.pipe = pipe + + @torch.inference_mode() + def patch(self): + path = self.__style["path"] + if str(path).endswith((".pt", ".safetensors")): + patch_pipe(self.pipe, self.__style["path"]) + tune_lora_scale(self.pipe.unet, self.__style["weight"]) + tune_lora_scale(self.pipe.text_encoder, self.__style["weight"]) + + def kwargs(self): + return {} + + def cleanup(self): + tune_lora_scale(self.pipe.unet, 0.0) + tune_lora_scale(self.pipe.text_encoder, 0.0) + pass + + class EmptyLoraPatcher: + def __init__(self, pipe): + self.pipe = pipe + + def patch(self): + "Patch will act as cleanup, to tune down any corrupted lora" + self.cleanup() + pass + + def kwargs(self): + return {} + + def cleanup(self): + tune_lora_scale(self.pipe.unet, 0.0) + tune_lora_scale(self.pipe.text_encoder, 0.0) + pass + + def load(self, model_dir: str): + self.model = model_dir + self.fetch_styles() + + def fetch_styles(self): + model_dir = self.model + result = getStyles() + if result is not None: + self.__styles = self.__parse_styles(model_dir, result["data"]) + else: + self.__styles = self.__get_default_styles(model_dir) + self.__verify() + + def prepend_style_to_prompt(self, prompt: str, key: str) -> str: + if key in self.__styles: + style = self.__styles[key] + return f"{', '.join(style['text'])}, {prompt}" + return prompt + + def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]: + if key in self.__styles: + style = self.__styles[key] + return self.LoraPatcher(pipe, style) + return self.EmptyLoraPatcher(pipe) + + def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict: + styles = {} + download_dir = Path(Path.home() / ".cache" / "lora") + download_dir.mkdir(exist_ok=True) + data = chain(data).uniq_by(lambda x: x["tag"]).value() + for item in data: + if item["attributes"] is not None: + attr = json.loads(item["attributes"]) + if "path" in attr: + file_path = Path(download_dir / attr["path"].split("/")[-1]) + + if not file_path.exists(): + s3_uri = attr["path"] + download_file(s3_uri, file_path) + + styles[item["tag"]] = { + "path": str(file_path), + "weight": attr["weight"], + "type": attr["type"], + "text": attr["text"], + "negativePrompt": attr["negativePrompt"], + } + if len(styles) == 0: + return self.__get_default_styles(model_dir) + return styles + + def __get_default_styles(self, model_dir: str) -> Dict: + return { + "nq6akX1CIp": { + "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors", + "text": ["nq6akX1CIp style"], + "weight": 0.5, + "negativePrompt": [""], + "type": "custom", + }, + "ghibli": { + "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin", + "text": ["ghibli style"], + "weight": 1, + "negativePrompt": [""], + "type": "custom", + }, + "eQAmnK2kB2": { + "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors", + "text": ["eQAmnK2kB2 style"], + "weight": 0.5, + "negativePrompt": [""], + "type": "custom", + }, + "to8contrast": { + "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin", + "text": ["to8contrast style"], + "weight": 0.5, + "negativePrompt": [""], + "type": "custom", + }, + "sfrrfz8vge": { + "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors", + "text": ["sfrrfz8vge style"], + "weight": 1.2, + "negativePrompt": [""], + "type": "custom", + }, + } + + def __verify(self): + "A method to verify if lora exists within the required path otherwise throw error" + + for item in self.__styles.keys(): + if not os.path.exists(self.__styles[item]["path"]): + raise Exception( + "Lora style model " + + item + + " not found at path: " + + self.__styles[item]["path"] + ) diff --git a/internals/util/slack.py b/internals/util/slack.py new file mode 100644 index 0000000000000000000000000000000000000000..81e99ef502858304e2de3e638c656df9f3c8102a --- /dev/null +++ b/internals/util/slack.py @@ -0,0 +1,58 @@ +from time import sleep +from typing import Optional + +import requests + +from internals.data.task import Task +from internals.util.config import get_environment + + +class Slack: + def __init__(self): + # self.webhook_url = "https://hooks.slack.com/services/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" + self.webhook_url = "https://hooks.slack.com/services/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" + self.error_webhook = "https://hooks.slack.com/services/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" + + def send_alert(self, task: Task, args: Optional[dict]): + raw = task.get_raw().copy() + + raw["environment"] = get_environment() + raw.pop("queue_name", None) + raw.pop("attempt", None) + raw.pop("timestamp", None) + raw.pop("task_id", None) + raw.pop("maskImageUrl", None) + + if args is not None: + raw.update(args.items()) + + message = "" + for key, value in raw.items(): + if value: + if type(value) == list: + message += f"*{key}*: {', '.join(value)}\n" + else: + message += f"*{key}*: {value}\n" + + requests.post( + self.webhook_url, + headers={"Content-Type": "application/json"}, + json={"text": message}, + ) + + def error_alert(self, task: Task, e: Exception): + requests.post( + self.error_webhook, + headers={"Content-Type": "application/json"}, + json={ + "text": "Task failed:\n{} \n error is: \n {}".format(task.get_raw(), e) + }, + ) + + def auto_send_alert(self, func): + def inner(*args, **kwargs): + rargs = func(*args, **kwargs) + self.send_alert(args[0], rargs) + return rargs + + return inner diff --git a/models/ade20k/.DS_Store b/models/ade20k/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..a947acca1e1e3d10608cb935fdb95c532acaae23 Binary files /dev/null and b/models/ade20k/.DS_Store differ diff --git a/models/ade20k/__init__.py b/models/ade20k/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..773cfc4664eef45a4f6fe05bd3fe2aa2143fdb5c --- /dev/null +++ b/models/ade20k/__init__.py @@ -0,0 +1 @@ +from .base import * \ No newline at end of file diff --git a/models/ade20k/base.py b/models/ade20k/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8cdbe2d3e7dbadf4ed5e5a7cf2d248761ef25d9c --- /dev/null +++ b/models/ade20k/base.py @@ -0,0 +1,627 @@ +"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" + +import os + +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.io import loadmat +from torch.nn.modules import BatchNorm2d + +from . import resnet +from . import mobilenet + + +NUM_CLASS = 150 +base_path = os.path.dirname(os.path.abspath(__file__)) # current file path +colors_path = os.path.join(base_path, 'color150.mat') +classes_path = os.path.join(base_path, 'object150_info.csv') + +segm_options = dict(colors=loadmat(colors_path)['colors'], + classes=pd.read_csv(classes_path),) + + +class NormalizeTensor: + def __init__(self, mean, std, inplace=False): + """Normalize a tensor image with mean and standard deviation. + .. note:: + This transform acts out of place by default, i.e., it does not mutates the input tensor. + See :class:`~torchvision.transforms.Normalize` for more details. + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation inplace. + Returns: + Tensor: Normalized Tensor image. + """ + + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, tensor): + if not self.inplace: + tensor = tensor.clone() + + dtype = tensor.dtype + mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device) + std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device) + tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) + return tensor + + +# Model Builder +class ModelBuilder: + # custom weights initialization + @staticmethod + def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) + m.bias.data.fill_(1e-4) + + @staticmethod + def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''): + pretrained = True if len(weights) == 0 else False + arch = arch.lower() + if arch == 'mobilenetv2dilated': + orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained) + net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) + elif arch == 'resnet18': + orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + elif arch == 'resnet18dilated': + orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50dilated': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) + elif arch == 'resnet50': + orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) + net_encoder = Resnet(orig_resnet) + else: + raise Exception('Architecture undefined!') + + # encoders are usually pretrained + # net_encoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_encoder') + net_encoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_encoder + + @staticmethod + def build_decoder(arch='ppm_deepsup', + fc_dim=512, num_class=NUM_CLASS, + weights='', use_softmax=False, drop_last_conv=False): + arch = arch.lower() + if arch == 'ppm_deepsup': + net_decoder = PPMDeepsup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + elif arch == 'c1_deepsup': + net_decoder = C1DeepSup( + num_class=num_class, + fc_dim=fc_dim, + use_softmax=use_softmax, + drop_last_conv=drop_last_conv) + else: + raise Exception('Architecture undefined!') + + net_decoder.apply(ModelBuilder.weights_init) + if len(weights) > 0: + print('Loading weights for net_decoder') + net_decoder.load_state_dict( + torch.load(weights, map_location=lambda storage, loc: storage), strict=False) + return net_decoder + + @staticmethod + def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs): + path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth') + return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv) + + @staticmethod + def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation, + *arts, **kwargs): + if segmentation: + path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth') + else: + path = '' + return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path) + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), + BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class SegmentationModule(nn.Module): + def __init__(self, + weights_path, + num_classes=150, + arch_encoder="resnet50dilated", + drop_last_conv=False, + net_enc=None, # None for Default encoder + net_dec=None, # None for Default decoder + encode=None, # {None, 'binary', 'color', 'sky'} + use_default_normalization=False, + return_feature_maps=False, + return_feature_maps_level=3, # {0, 1, 2, 3} + return_feature_maps_only=True, + **kwargs, + ): + super().__init__() + self.weights_path = weights_path + self.drop_last_conv = drop_last_conv + self.arch_encoder = arch_encoder + if self.arch_encoder == "resnet50dilated": + self.arch_decoder = "ppm_deepsup" + self.fc_dim = 2048 + elif self.arch_encoder == "mobilenetv2dilated": + self.arch_decoder = "c1_deepsup" + self.fc_dim = 320 + else: + raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}") + model_builder_kwargs = dict(arch_encoder=self.arch_encoder, + arch_decoder=self.arch_decoder, + fc_dim=self.fc_dim, + drop_last_conv=drop_last_conv, + weights_path=self.weights_path) + + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc + self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec + self.use_default_normalization = use_default_normalization + self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + self.encode = encode + + self.return_feature_maps = return_feature_maps + + assert 0 <= return_feature_maps_level <= 3 + self.return_feature_maps_level = return_feature_maps_level + + def normalize_input(self, tensor): + if tensor.min() < 0 or tensor.max() > 1: + raise ValueError("Tensor should be 0..1 before using normalize_input") + return self.default_normalization(tensor) + + @property + def feature_maps_channels(self): + return 256 * 2**(self.return_feature_maps_level) # 256, 512, 1024, 2048 + + def forward(self, img_data, segSize=None): + if segSize is None: + raise NotImplementedError("Please pass segSize param. By default: (300, 300)") + + fmaps = self.encoder(img_data, return_feature_maps=True) + pred = self.decoder(fmaps, segSize=segSize) + + if self.return_feature_maps: + return pred, fmaps + # print("BINARY", img_data.shape, pred.shape) + return pred + + def multi_mask_from_multiclass(self, pred, classes): + def isin(ar1, ar2): + return (ar1[..., None] == ar2).any(-1).float() + return isin(pred, torch.LongTensor(classes).to(self.device)) + + @staticmethod + def multi_mask_from_multiclass_probs(scores, classes): + res = None + for c in classes: + if res is None: + res = scores[:, c] + else: + res += scores[:, c] + return res + + def predict(self, tensor, imgSizes=(-1,), # (300, 375, 450, 525, 600) + segSize=None): + """Entry-point for segmentation. Use this methods instead of forward + Arguments: + tensor {torch.Tensor} -- BCHW + Keyword Arguments: + imgSizes {tuple or list} -- imgSizes for segmentation input. + default: (300, 450) + original implementation: (300, 375, 450, 525, 600) + + """ + if segSize is None: + segSize = tensor.shape[-2:] + segSize = (tensor.shape[2], tensor.shape[3]) + with torch.no_grad(): + if self.use_default_normalization: + tensor = self.normalize_input(tensor) + scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device) + features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device) + + result = [] + for img_size in imgSizes: + if img_size != -1: + img_data = F.interpolate(tensor.clone(), size=img_size) + else: + img_data = tensor.clone() + + if self.return_feature_maps: + pred_current, fmaps = self.forward(img_data, segSize=segSize) + else: + pred_current = self.forward(img_data, segSize=segSize) + + + result.append(pred_current) + scores = scores + pred_current / len(imgSizes) + + # Disclaimer: We use and aggregate only last fmaps: fmaps[3] + if self.return_feature_maps: + features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes) + + _, pred = torch.max(scores, dim=1) + + if self.return_feature_maps: + return features + + return pred, result + + def get_edges(self, t): + edge = torch.cuda.ByteTensor(t.size()).zero_() + edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1]) + edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1]) + edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :]) + + if True: + return edge.half() + return edge.float() + + +# pyramid pooling, deep supervision +class PPMDeepsup(nn.Module): + def __init__(self, num_class=NUM_CLASS, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6), + drop_last_conv=False): + super().__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim + len(pool_scales) * 512, 512, + kernel_size=3, padding=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.dropout_deepsup = nn.Dropout2d(0.1) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + if self.drop_last_conv: + return ppm_out + else: + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.dropout_deepsup(_) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +class Resnet(nn.Module): + def __init__(self, orig_resnet): + super(Resnet, self).__init__() + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x); conv_out.append(x); + x = self.layer2(x); conv_out.append(x); + x = self.layer3(x); conv_out.append(x); + x = self.layer4(x); conv_out.append(x); + + if return_feature_maps: + return conv_out + return [x] + +# Resnet Dilated +class ResnetDilated(nn.Module): + def __init__(self, orig_resnet, dilate_scale=8): + super().__init__() + from functools import partial + + if dilate_scale == 8: + orig_resnet.layer3.apply( + partial(self._nostride_dilate, dilate=2)) + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=4)) + elif dilate_scale == 16: + orig_resnet.layer4.apply( + partial(self._nostride_dilate, dilate=2)) + + # take pretrained resnet, except AvgPool and FC + self.conv1 = orig_resnet.conv1 + self.bn1 = orig_resnet.bn1 + self.relu1 = orig_resnet.relu1 + self.conv2 = orig_resnet.conv2 + self.bn2 = orig_resnet.bn2 + self.relu2 = orig_resnet.relu2 + self.conv3 = orig_resnet.conv3 + self.bn3 = orig_resnet.bn3 + self.relu3 = orig_resnet.relu3 + self.maxpool = orig_resnet.maxpool + self.layer1 = orig_resnet.layer1 + self.layer2 = orig_resnet.layer2 + self.layer3 = orig_resnet.layer3 + self.layer4 = orig_resnet.layer4 + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate // 2, dilate // 2) + m.padding = (dilate // 2, dilate // 2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + conv_out = [] + + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + conv_out.append(x) + x = self.layer2(x) + conv_out.append(x) + x = self.layer3(x) + conv_out.append(x) + x = self.layer4(x) + conv_out.append(x) + + if return_feature_maps: + return conv_out + return [x] + +class MobileNetV2Dilated(nn.Module): + def __init__(self, orig_net, dilate_scale=8): + super(MobileNetV2Dilated, self).__init__() + from functools import partial + + # take pretrained mobilenet features + self.features = orig_net.features[:-1] + + self.total_idx = len(self.features) + self.down_idx = [2, 4, 7, 14] + + if dilate_scale == 8: + for i in range(self.down_idx[-2], self.down_idx[-1]): + self.features[i].apply( + partial(self._nostride_dilate, dilate=2) + ) + for i in range(self.down_idx[-1], self.total_idx): + self.features[i].apply( + partial(self._nostride_dilate, dilate=4) + ) + elif dilate_scale == 16: + for i in range(self.down_idx[-1], self.total_idx): + self.features[i].apply( + partial(self._nostride_dilate, dilate=2) + ) + + def _nostride_dilate(self, m, dilate): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + # the convolution with stride + if m.stride == (2, 2): + m.stride = (1, 1) + if m.kernel_size == (3, 3): + m.dilation = (dilate//2, dilate//2) + m.padding = (dilate//2, dilate//2) + # other convoluions + else: + if m.kernel_size == (3, 3): + m.dilation = (dilate, dilate) + m.padding = (dilate, dilate) + + def forward(self, x, return_feature_maps=False): + if return_feature_maps: + conv_out = [] + for i in range(self.total_idx): + x = self.features[i](x) + if i in self.down_idx: + conv_out.append(x) + conv_out.append(x) + return conv_out + + else: + return [self.features(x)] + + +# last conv, deep supervision +class C1DeepSup(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False): + super(C1DeepSup, self).__init__() + self.use_softmax = use_softmax + self.drop_last_conv = drop_last_conv + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + x = self.cbr(conv5) + + if self.drop_last_conv: + return x + else: + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + return x + + # deep sup + conv4 = conv_out[-2] + _ = self.cbr_deepsup(conv4) + _ = self.conv_last_deepsup(_) + + x = nn.functional.log_softmax(x, dim=1) + _ = nn.functional.log_softmax(_, dim=1) + + return (x, _) + + +# last conv +class C1(nn.Module): + def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): + super(C1, self).__init__() + self.use_softmax = use_softmax + + self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) + + # last conv + self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + x = self.cbr(conv5) + x = self.conv_last(x) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + + return x + + +# pyramid pooling +class PPM(nn.Module): + def __init__(self, num_class=150, fc_dim=4096, + use_softmax=False, pool_scales=(1, 2, 3, 6)): + super(PPM, self).__init__() + self.use_softmax = use_softmax + + self.ppm = [] + for scale in pool_scales: + self.ppm.append(nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True) + )) + self.ppm = nn.ModuleList(self.ppm) + + self.conv_last = nn.Sequential( + nn.Conv2d(fc_dim+len(pool_scales)*512, 512, + kernel_size=3, padding=1, bias=False), + BatchNorm2d(512), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(512, num_class, kernel_size=1) + ) + + def forward(self, conv_out, segSize=None): + conv5 = conv_out[-1] + + input_size = conv5.size() + ppm_out = [conv5] + for pool_scale in self.ppm: + ppm_out.append(nn.functional.interpolate( + pool_scale(conv5), + (input_size[2], input_size[3]), + mode='bilinear', align_corners=False)) + ppm_out = torch.cat(ppm_out, 1) + + x = self.conv_last(ppm_out) + + if self.use_softmax: # is True during inference + x = nn.functional.interpolate( + x, size=segSize, mode='bilinear', align_corners=False) + x = nn.functional.softmax(x, dim=1) + else: + x = nn.functional.log_softmax(x, dim=1) + return x diff --git a/models/ade20k/color150.mat b/models/ade20k/color150.mat new file mode 100644 index 0000000000000000000000000000000000000000..c518b64fbbe899d4a8b2705f012eeba795339892 Binary files /dev/null and b/models/ade20k/color150.mat differ diff --git a/models/ade20k/mobilenet.py b/models/ade20k/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..f501266e56ee71cdf455744020f8fc1a58ec9fff --- /dev/null +++ b/models/ade20k/mobilenet.py @@ -0,0 +1,154 @@ +""" +This MobileNetV2 implementation is modified from the following repository: +https://github.com/tonylins/pytorch-mobilenet-v2 +""" + +import torch.nn as nn +import math +from .utils import load_url +from .segm_lib.nn import SynchronizedBatchNorm2d + +BatchNorm2d = SynchronizedBatchNorm2d + + +__all__ = ['mobilenetv2'] + + +model_urls = { + 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', +} + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + self.use_res_connect = self.stride == 1 and inp == oup + + if expand_ratio == 1: + self.conv = nn.Sequential( + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + ) + else: + self.conv = nn.Sequential( + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), + BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, n_class=1000, input_size=224, width_mult=1.): + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + assert input_size % 32 == 0 + input_channel = int(input_channel * width_mult) + self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel + self.features = [conv_bn(3, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * width_mult) + for i in range(n): + if i == 0: + self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) + else: + self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + # building last several layers + self.features.append(conv_1x1_bn(input_channel, self.last_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, n_class), + ) + + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = x.mean(3).mean(2) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +def mobilenetv2(pretrained=False, **kwargs): + """Constructs a MobileNet_V2 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = MobileNetV2(n_class=1000, **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) + return model \ No newline at end of file diff --git a/models/ade20k/object150_info.csv b/models/ade20k/object150_info.csv new file mode 100644 index 0000000000000000000000000000000000000000..8b34d8f3874a38b96894863c5458a7c3c2b0e2e6 --- /dev/null +++ b/models/ade20k/object150_info.csv @@ -0,0 +1,151 @@ +Idx,Ratio,Train,Val,Stuff,Name +1,0.1576,11664,1172,1,wall +2,0.1072,6046,612,1,building;edifice +3,0.0878,8265,796,1,sky +4,0.0621,9336,917,1,floor;flooring +5,0.0480,6678,641,0,tree +6,0.0450,6604,643,1,ceiling +7,0.0398,4023,408,1,road;route +8,0.0231,1906,199,0,bed +9,0.0198,4688,460,0,windowpane;window +10,0.0183,2423,225,1,grass +11,0.0181,2874,294,0,cabinet +12,0.0166,3068,310,1,sidewalk;pavement +13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul +14,0.0151,1804,190,1,earth;ground +15,0.0118,6666,796,0,door;double;door +16,0.0110,4269,411,0,table +17,0.0109,1691,160,1,mountain;mount +18,0.0104,3999,441,0,plant;flora;plant;life +19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall +20,0.0103,3261,318,0,chair +21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar +22,0.0074,709,75,1,water +23,0.0067,3296,315,0,painting;picture +24,0.0065,1191,106,0,sofa;couch;lounge +25,0.0061,1516,162,0,shelf +26,0.0060,667,69,1,house +27,0.0053,651,57,1,sea +28,0.0052,1847,224,0,mirror +29,0.0046,1158,128,1,rug;carpet;carpeting +30,0.0044,480,44,1,field +31,0.0044,1172,98,0,armchair +32,0.0044,1292,184,0,seat +33,0.0033,1386,138,0,fence;fencing +34,0.0031,698,61,0,desk +35,0.0030,781,73,0,rock;stone +36,0.0027,380,43,0,wardrobe;closet;press +37,0.0026,3089,302,0,lamp +38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub +39,0.0024,804,99,0,railing;rail +40,0.0023,1453,153,0,cushion +41,0.0023,411,37,0,base;pedestal;stand +42,0.0022,1440,162,0,box +43,0.0022,800,77,0,column;pillar +44,0.0020,2650,298,0,signboard;sign +45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser +46,0.0019,367,36,0,counter +47,0.0018,311,30,1,sand +48,0.0018,1181,122,0,sink +49,0.0018,287,23,1,skyscraper +50,0.0018,468,38,0,fireplace;hearth;open;fireplace +51,0.0018,402,43,0,refrigerator;icebox +52,0.0018,130,12,1,grandstand;covered;stand +53,0.0018,561,64,1,path +54,0.0017,880,102,0,stairs;steps +55,0.0017,86,12,1,runway +56,0.0017,172,11,0,case;display;case;showcase;vitrine +57,0.0017,198,18,0,pool;table;billiard;table;snooker;table +58,0.0017,930,109,0,pillow +59,0.0015,139,18,0,screen;door;screen +60,0.0015,564,52,1,stairway;staircase +61,0.0015,320,26,1,river +62,0.0015,261,29,1,bridge;span +63,0.0014,275,22,0,bookcase +64,0.0014,335,60,0,blind;screen +65,0.0014,792,75,0,coffee;table;cocktail;table +66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne +67,0.0014,1309,138,0,flower +68,0.0013,1112,113,0,book +69,0.0013,266,27,1,hill +70,0.0013,659,66,0,bench +71,0.0012,331,31,0,countertop +72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove +73,0.0012,369,36,0,palm;palm;tree +74,0.0012,144,9,0,kitchen;island +75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system +76,0.0010,324,33,0,swivel;chair +77,0.0009,304,27,0,boat +78,0.0009,170,20,0,bar +79,0.0009,68,6,0,arcade;machine +80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty +81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle +82,0.0008,492,49,0,towel +83,0.0008,2510,269,0,light;light;source +84,0.0008,440,39,0,truck;motortruck +85,0.0008,147,18,1,tower +86,0.0008,583,56,0,chandelier;pendant;pendent +87,0.0007,533,61,0,awning;sunshade;sunblind +88,0.0007,1989,239,0,streetlight;street;lamp +89,0.0007,71,5,0,booth;cubicle;stall;kiosk +90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box +91,0.0007,135,12,0,airplane;aeroplane;plane +92,0.0007,83,5,1,dirt;track +93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes +94,0.0006,1003,104,0,pole +95,0.0006,182,12,1,land;ground;soil +96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail +97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway +98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock +99,0.0006,965,114,0,bottle +100,0.0006,117,13,0,buffet;counter;sideboard +101,0.0006,354,35,0,poster;posting;placard;notice;bill;card +102,0.0006,108,9,1,stage +103,0.0006,557,55,0,van +104,0.0006,52,4,0,ship +105,0.0005,99,5,0,fountain +106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter +107,0.0005,292,31,0,canopy +108,0.0005,77,9,0,washer;automatic;washer;washing;machine +109,0.0005,340,38,0,plaything;toy +110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium +111,0.0005,465,49,0,stool +112,0.0005,50,4,0,barrel;cask +113,0.0005,622,75,0,basket;handbasket +114,0.0005,80,9,1,waterfall;falls +115,0.0005,59,3,0,tent;collapsible;shelter +116,0.0005,531,72,0,bag +117,0.0005,282,30,0,minibike;motorbike +118,0.0005,73,7,0,cradle +119,0.0005,435,44,0,oven +120,0.0005,136,25,0,ball +121,0.0005,116,24,0,food;solid;food +122,0.0004,266,31,0,step;stair +123,0.0004,58,12,0,tank;storage;tank +124,0.0004,418,83,0,trade;name;brand;name;brand;marque +125,0.0004,319,43,0,microwave;microwave;oven +126,0.0004,1193,139,0,pot;flowerpot +127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna +128,0.0004,347,36,0,bicycle;bike;wheel;cycle +129,0.0004,52,5,1,lake +130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine +131,0.0004,108,13,0,screen;silver;screen;projection;screen +132,0.0004,201,30,0,blanket;cover +133,0.0004,285,21,0,sculpture +134,0.0004,268,27,0,hood;exhaust;hood +135,0.0003,1020,108,0,sconce +136,0.0003,1282,122,0,vase +137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight +138,0.0003,453,57,0,tray +139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin +140,0.0003,397,44,0,fan +141,0.0003,92,8,1,pier;wharf;wharfage;dock +142,0.0003,228,18,0,crt;screen +143,0.0003,570,59,0,plate +144,0.0003,217,22,0,monitor;monitoring;device +145,0.0003,206,19,0,bulletin;board;notice;board +146,0.0003,130,14,0,shower +147,0.0003,178,28,0,radiator +148,0.0002,504,57,0,glass;drinking;glass +149,0.0002,775,96,0,clock +150,0.0002,421,56,0,flag diff --git a/models/ade20k/resnet.py b/models/ade20k/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3e1d521f171c984cf6a7ff3dcebd96f8c5faf908 --- /dev/null +++ b/models/ade20k/resnet.py @@ -0,0 +1,181 @@ +"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" + +import math + +import torch.nn as nn +from torch.nn import BatchNorm2d + +from .utils import load_url + +__all__ = ['ResNet', 'resnet50'] + + +model_urls = { + 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 128 + super(ResNet, self).__init__() + self.conv1 = conv3x3(3, 64, stride=2) + self.bn1 = BatchNorm2d(64) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = conv3x3(64, 64) + self.bn2 = BatchNorm2d(64) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = conv3x3(64, 128) + self.bn3 = BatchNorm2d(128) + self.relu3 = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet50']), strict=False) + return model + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(load_url(model_urls['resnet18'])) + return model \ No newline at end of file diff --git a/models/ade20k/segm_lib/.DS_Store b/models/ade20k/segm_lib/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..56160ef3245ede32d722e7166700a20b006d0b7e Binary files /dev/null and b/models/ade20k/segm_lib/.DS_Store differ diff --git a/models/ade20k/segm_lib/nn/.DS_Store b/models/ade20k/segm_lib/nn/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7b4aa135e992184d67b70d5342cf93a2057a3bf4 Binary files /dev/null and b/models/ade20k/segm_lib/nn/.DS_Store differ diff --git a/models/ade20k/segm_lib/nn/__init__.py b/models/ade20k/segm_lib/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..98a96370ef04570f516052bb73f568d0ebc346c3 --- /dev/null +++ b/models/ade20k/segm_lib/nn/__init__.py @@ -0,0 +1,2 @@ +from .modules import * +from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/models/ade20k/segm_lib/nn/modules/__init__.py b/models/ade20k/segm_lib/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bc8709d92c610b36e0bcbd7da20c1eb41dc8cfcf --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/models/ade20k/segm_lib/nn/modules/batchnorm.py b/models/ade20k/segm_lib/nn/modules/batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..18318965335b37cc671004a6aceda3229dc7b477 --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/batchnorm.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + # customed batch norm statistics + self._moving_average_fraction = 1. - momentum + self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) + self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) + self.register_buffer('_running_iter', torch.ones(1)) + self._tmp_running_mean = self.running_mean.clone() * self._running_iter + self._tmp_running_var = self.running_var.clone() * self._running_iter + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): + """return *dest* by `dest := dest*alpha + delta*beta + bias`""" + return dest * alpha + delta * beta + bias + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) + self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) + self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) + + self.running_mean = self._tmp_running_mean / self._running_iter + self.running_var = self._tmp_running_var / self._running_iter + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/models/ade20k/segm_lib/nn/modules/comm.py b/models/ade20k/segm_lib/nn/modules/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..b64bf6ba3b3e7abbab375c6dd4a87d8239e62138 --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/comm.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/models/ade20k/segm_lib/nn/modules/replicate.py b/models/ade20k/segm_lib/nn/modules/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c7b8ed51a1d6c55b1f753bdd8d90bad79bd06 --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py b/models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..8bd45a930d3dc84912e58659ee575be08e9038f0 --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py b/models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..45bb3c8cfd36d8f668e6fde756b17587eab72082 --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# File : test_sync_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. + +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/models/ade20k/segm_lib/nn/modules/unittest.py b/models/ade20k/segm_lib/nn/modules/unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..0675c022e4ba85d38d1f813490f6740150909524 --- /dev/null +++ b/models/ade20k/segm_lib/nn/modules/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/models/ade20k/segm_lib/nn/parallel/__init__.py b/models/ade20k/segm_lib/nn/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52f49cc0755562218a460483cbf02514ddd773 --- /dev/null +++ b/models/ade20k/segm_lib/nn/parallel/__init__.py @@ -0,0 +1 @@ +from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/models/ade20k/segm_lib/nn/parallel/data_parallel.py b/models/ade20k/segm_lib/nn/parallel/data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..376fc038919aa2a5bd696141e7bb6025d4981306 --- /dev/null +++ b/models/ade20k/segm_lib/nn/parallel/data_parallel.py @@ -0,0 +1,112 @@ +# -*- coding: utf8 -*- + +import torch.cuda as cuda +import torch.nn as nn +import torch +import collections +from torch.nn.parallel._functions import Gather + + +__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] + + +def async_copy_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + v = obj.cuda(dev, non_blocking=True) + if main_stream is not None: + v.data.record_stream(main_stream) + return v + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + else: + return obj + + +def dict_gather(outputs, target_device, dim=0): + """ + Gathers variables from different GPUs on a specified device + (-1 means the CPU), with dictionary support. + """ + def gather_map(outputs): + out = outputs[0] + if torch.is_tensor(out): + # MJY(20180330) HACK:: force nr_dims > 0 + if out.dim() == 0: + outputs = [o.unsqueeze(0) for o in outputs] + return Gather.apply(target_device, dim, *outputs) + elif out is None: + return None + elif isinstance(out, collections.Mapping): + return {k: gather_map([o[k] for o in outputs]) for k in out} + elif isinstance(out, collections.Sequence): + return type(out)(map(gather_map, zip(*outputs))) + return gather_map(outputs) + + +class DictGatherDataParallel(nn.DataParallel): + def gather(self, outputs, output_device): + return dict_gather(outputs, output_device, dim=self.dim) + + +class UserScatteredDataParallel(DictGatherDataParallel): + def scatter(self, inputs, kwargs, device_ids): + assert len(inputs) == 1 + inputs = inputs[0] + inputs = _async_copy_stream(inputs, device_ids) + inputs = [[i] for i in inputs] + assert len(kwargs) == 0 + kwargs = [{} for _ in range(len(inputs))] + + return inputs, kwargs + + +def user_scattered_collate(batch): + return batch + + +def _async_copy(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + for i, dev in zip(inputs, device_ids): + with cuda.device(dev): + outputs.append(async_copy_to(i, dev)) + + return tuple(outputs) + + +def _async_copy_stream(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + streams = [_get_stream(d) for d in device_ids] + for i, dev, stream in zip(inputs, device_ids, streams): + with cuda.device(dev): + main_stream = cuda.current_stream() + with cuda.stream(stream): + outputs.append(async_copy_to(i, dev, main_stream=main_stream)) + main_stream.wait_stream(stream) + + return outputs + + +"""Adapted from: torch/nn/parallel/_functions.py""" +# background streams used for copying +_streams = None + + +def _get_stream(device): + """Gets a background stream for copying between CPU and GPU""" + global _streams + if device == -1: + return None + if _streams is None: + _streams = [None] * cuda.device_count() + if _streams[device] is None: _streams[device] = cuda.Stream(device) + return _streams[device] diff --git a/models/ade20k/segm_lib/utils/__init__.py b/models/ade20k/segm_lib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abe3cbe49477fe37d4fc16249de8a10f4fb4a013 --- /dev/null +++ b/models/ade20k/segm_lib/utils/__init__.py @@ -0,0 +1 @@ +from .th import * diff --git a/models/ade20k/segm_lib/utils/data/__init__.py b/models/ade20k/segm_lib/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b008fb13c5e8a84b1b785056e8c4f5226dc976 --- /dev/null +++ b/models/ade20k/segm_lib/utils/data/__init__.py @@ -0,0 +1,3 @@ + +from .dataset import Dataset, TensorDataset, ConcatDataset +from .dataloader import DataLoader diff --git a/models/ade20k/segm_lib/utils/data/dataloader.py b/models/ade20k/segm_lib/utils/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..039b9ec3645b2a4626ff47c221e372f32a6ad339 --- /dev/null +++ b/models/ade20k/segm_lib/utils/data/dataloader.py @@ -0,0 +1,425 @@ +import torch +import torch.multiprocessing as multiprocessing +from torch._C import _set_worker_signal_handlers, \ + _remove_worker_pids, _error_if_any_worker_fails +try: + from torch._C import _set_worker_pids +except: + from torch._C import _update_worker_pids as _set_worker_pids +from .sampler import SequentialSampler, RandomSampler, BatchSampler +import signal +import collections +import re +import sys +import threading +import traceback +from torch._six import string_classes, int_classes +import numpy as np + +if sys.version_info[0] == 2: + import Queue as queue +else: + import queue + + +class ExceptionWrapper(object): + r"Wraps an exception plus traceback to communicate across threads" + + def __init__(self, exc_info): + self.exc_type = exc_info[0] + self.exc_msg = "".join(traceback.format_exception(*exc_info)) + + +_use_shared_memory = False +"""Whether to use shared memory in default_collate""" + + +def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): + global _use_shared_memory + _use_shared_memory = True + + # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal happened again already. + # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 + _set_worker_signal_handlers() + + torch.set_num_threads(1) + torch.manual_seed(seed) + np.random.seed(seed) + + if init_fn is not None: + init_fn(worker_id) + + while True: + r = index_queue.get() + if r is None: + break + idx, batch_indices = r + try: + samples = collate_fn([dataset[i] for i in batch_indices]) + except Exception: + data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + data_queue.put((idx, samples)) + + +def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): + if pin_memory: + torch.cuda.set_device(device_id) + + while True: + try: + r = in_queue.get() + except Exception: + if done_event.is_set(): + return + raise + if r is None: + break + if isinstance(r[1], ExceptionWrapper): + out_queue.put(r) + continue + idx, batch = r + try: + if pin_memory: + batch = pin_memory_batch(batch) + except Exception: + out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) + else: + out_queue.put((idx, batch)) + +numpy_type_map = { + 'float64': torch.DoubleTensor, + 'float32': torch.FloatTensor, + 'float16': torch.HalfTensor, + 'int64': torch.LongTensor, + 'int32': torch.IntTensor, + 'int16': torch.ShortTensor, + 'int8': torch.CharTensor, + 'uint8': torch.ByteTensor, +} + + +def default_collate(batch): + "Puts each data field into a tensor with outer dimension batch size" + + error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" + elem_type = type(batch[0]) + if torch.is_tensor(batch[0]): + out = None + if _use_shared_memory: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = batch[0].storage()._new_shared(numel) + out = batch[0].new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + elem = batch[0] + if elem_type.__name__ == 'ndarray': + # array of string classes and object + if re.search('[SaUO]', elem.dtype.str) is not None: + raise TypeError(error_msg.format(elem.dtype)) + + return torch.stack([torch.from_numpy(b) for b in batch], 0) + if elem.shape == (): # scalars + py_type = float if elem.dtype.name.startswith('float') else int + return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) + elif isinstance(batch[0], int_classes): + return torch.LongTensor(batch) + elif isinstance(batch[0], float): + return torch.DoubleTensor(batch) + elif isinstance(batch[0], string_classes): + return batch + elif isinstance(batch[0], collections.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in batch[0]} + elif isinstance(batch[0], collections.Sequence): + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError((error_msg.format(type(batch[0])))) + + +def pin_memory_batch(batch): + if torch.is_tensor(batch): + return batch.pin_memory() + elif isinstance(batch, string_classes): + return batch + elif isinstance(batch, collections.Mapping): + return {k: pin_memory_batch(sample) for k, sample in batch.items()} + elif isinstance(batch, collections.Sequence): + return [pin_memory_batch(sample) for sample in batch] + else: + return batch + + +_SIGCHLD_handler_set = False +"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one +handler needs to be set for all DataLoaders in a process.""" + + +def _set_SIGCHLD_handler(): + # Windows doesn't support SIGCHLD handler + if sys.platform == 'win32': + return + # can't set signal in child threads + if not isinstance(threading.current_thread(), threading._MainThread): + return + global _SIGCHLD_handler_set + if _SIGCHLD_handler_set: + return + previous_handler = signal.getsignal(signal.SIGCHLD) + if not callable(previous_handler): + previous_handler = None + + def handler(signum, frame): + # This following call uses `waitid` with WNOHANG from C side. Therefore, + # Python can still get and update the process status successfully. + _error_if_any_worker_fails() + if previous_handler is not None: + previous_handler(signum, frame) + + signal.signal(signal.SIGCHLD, handler) + _SIGCHLD_handler_set = True + + +class DataLoaderIter(object): + "Iterates once over the DataLoader's dataset, as specified by the sampler" + + def __init__(self, loader): + self.dataset = loader.dataset + self.collate_fn = loader.collate_fn + self.batch_sampler = loader.batch_sampler + self.num_workers = loader.num_workers + self.pin_memory = loader.pin_memory and torch.cuda.is_available() + self.timeout = loader.timeout + self.done_event = threading.Event() + + self.sample_iter = iter(self.batch_sampler) + + if self.num_workers > 0: + self.worker_init_fn = loader.worker_init_fn + self.index_queue = multiprocessing.SimpleQueue() + self.worker_result_queue = multiprocessing.SimpleQueue() + self.batches_outstanding = 0 + self.worker_pids_set = False + self.shutdown = False + self.send_idx = 0 + self.rcvd_idx = 0 + self.reorder_dict = {} + + base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] + self.workers = [ + multiprocessing.Process( + target=_worker_loop, + args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, + base_seed + i, self.worker_init_fn, i)) + for i in range(self.num_workers)] + + if self.pin_memory or self.timeout > 0: + self.data_queue = queue.Queue() + if self.pin_memory: + maybe_device_id = torch.cuda.current_device() + else: + # do not initialize cuda context if not necessary + maybe_device_id = None + self.worker_manager_thread = threading.Thread( + target=_worker_manager_loop, + args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, + maybe_device_id)) + self.worker_manager_thread.daemon = True + self.worker_manager_thread.start() + else: + self.data_queue = self.worker_result_queue + + for w in self.workers: + w.daemon = True # ensure that the worker exits on process exit + w.start() + + _set_worker_pids(id(self), tuple(w.pid for w in self.workers)) + _set_SIGCHLD_handler() + self.worker_pids_set = True + + # prime the prefetch loop + for _ in range(2 * self.num_workers): + self._put_indices() + + def __len__(self): + return len(self.batch_sampler) + + def _get_batch(self): + if self.timeout > 0: + try: + return self.data_queue.get(timeout=self.timeout) + except queue.Empty: + raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) + else: + return self.data_queue.get() + + def __next__(self): + if self.num_workers == 0: # same-process loading + indices = next(self.sample_iter) # may raise StopIteration + batch = self.collate_fn([self.dataset[i] for i in indices]) + if self.pin_memory: + batch = pin_memory_batch(batch) + return batch + + # check if the next sample has already been generated + if self.rcvd_idx in self.reorder_dict: + batch = self.reorder_dict.pop(self.rcvd_idx) + return self._process_next_batch(batch) + + if self.batches_outstanding == 0: + self._shutdown_workers() + raise StopIteration + + while True: + assert (not self.shutdown and self.batches_outstanding > 0) + idx, batch = self._get_batch() + self.batches_outstanding -= 1 + if idx != self.rcvd_idx: + # store out-of-order samples + self.reorder_dict[idx] = batch + continue + return self._process_next_batch(batch) + + next = __next__ # Python 2 compatibility + + def __iter__(self): + return self + + def _put_indices(self): + assert self.batches_outstanding < 2 * self.num_workers + indices = next(self.sample_iter, None) + if indices is None: + return + self.index_queue.put((self.send_idx, indices)) + self.batches_outstanding += 1 + self.send_idx += 1 + + def _process_next_batch(self, batch): + self.rcvd_idx += 1 + self._put_indices() + if isinstance(batch, ExceptionWrapper): + raise batch.exc_type(batch.exc_msg) + return batch + + def __getstate__(self): + # TODO: add limited pickling support for sharing an iterator + # across multiple threads for HOGWILD. + # Probably the best way to do this is by moving the sample pushing + # to a separate thread and then just sharing the data queue + # but signalling the end is tricky without a non-blocking API + raise NotImplementedError("DataLoaderIterator cannot be pickled") + + def _shutdown_workers(self): + try: + if not self.shutdown: + self.shutdown = True + self.done_event.set() + # if worker_manager_thread is waiting to put + while not self.data_queue.empty(): + self.data_queue.get() + for _ in self.workers: + self.index_queue.put(None) + # done_event should be sufficient to exit worker_manager_thread, + # but be safe here and put another None + self.worker_result_queue.put(None) + finally: + # removes pids no matter what + if self.worker_pids_set: + _remove_worker_pids(id(self)) + self.worker_pids_set = False + + def __del__(self): + if self.num_workers > 0: + self._shutdown_workers() + + +class DataLoader(object): + """ + Data loader. Combines a dataset and a sampler, and provides + single- or multi-process iterators over the dataset. + + Arguments: + dataset (Dataset): dataset from which to load the data. + batch_size (int, optional): how many samples per batch to load + (default: 1). + shuffle (bool, optional): set to ``True`` to have the data reshuffled + at every epoch (default: False). + sampler (Sampler, optional): defines the strategy to draw samples from + the dataset. If specified, ``shuffle`` must be False. + batch_sampler (Sampler, optional): like sampler, but returns a batch of + indices at a time. Mutually exclusive with batch_size, shuffle, + sampler, and drop_last. + num_workers (int, optional): how many subprocesses to use for data + loading. 0 means that the data will be loaded in the main process. + (default: 0) + collate_fn (callable, optional): merges a list of samples to form a mini-batch. + pin_memory (bool, optional): If ``True``, the data loader will copy tensors + into CUDA pinned memory before returning them. + drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, + if the dataset size is not divisible by the batch size. If ``False`` and + the size of dataset is not divisible by the batch size, then the last batch + will be smaller. (default: False) + timeout (numeric, optional): if positive, the timeout value for collecting a batch + from workers. Should always be non-negative. (default: 0) + worker_init_fn (callable, optional): If not None, this will be called on each + worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as + input, after seeding and before data loading. (default: None) + + .. note:: By default, each worker will have its PyTorch seed set to + ``base_seed + worker_id``, where ``base_seed`` is a long generated + by main process using its RNG. You may use ``torch.initial_seed()`` to access + this value in :attr:`worker_init_fn`, which can be used to set other seeds + (e.g. NumPy) before data loading. + + .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an + unpicklable object, e.g., a lambda function. + """ + + def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, + num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None): + self.dataset = dataset + self.batch_size = batch_size + self.num_workers = num_workers + self.collate_fn = collate_fn + self.pin_memory = pin_memory + self.drop_last = drop_last + self.timeout = timeout + self.worker_init_fn = worker_init_fn + + if timeout < 0: + raise ValueError('timeout option should be non-negative') + + if batch_sampler is not None: + if batch_size > 1 or shuffle or sampler is not None or drop_last: + raise ValueError('batch_sampler is mutually exclusive with ' + 'batch_size, shuffle, sampler, and drop_last') + + if sampler is not None and shuffle: + raise ValueError('sampler is mutually exclusive with shuffle') + + if self.num_workers < 0: + raise ValueError('num_workers cannot be negative; ' + 'use num_workers=0 to disable multiprocessing.') + + if batch_sampler is None: + if sampler is None: + if shuffle: + sampler = RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, batch_size, drop_last) + + self.sampler = sampler + self.batch_sampler = batch_sampler + + def __iter__(self): + return DataLoaderIter(self) + + def __len__(self): + return len(self.batch_sampler) diff --git a/models/ade20k/segm_lib/utils/data/dataset.py b/models/ade20k/segm_lib/utils/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..605aa877f7031a5cd2b98c0f831410aa80fddefa --- /dev/null +++ b/models/ade20k/segm_lib/utils/data/dataset.py @@ -0,0 +1,118 @@ +import bisect +import warnings + +from torch._utils import _accumulate +from torch import randperm + + +class Dataset(object): + """An abstract class representing a Dataset. + + All other datasets should subclass it. All subclasses should override + ``__len__``, that provides the size of the dataset, and ``__getitem__``, + supporting integer indexing in range from 0 to len(self) exclusive. + """ + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def __add__(self, other): + return ConcatDataset([self, other]) + + +class TensorDataset(Dataset): + """Dataset wrapping data and target tensors. + + Each sample will be retrieved by indexing both tensors along the first + dimension. + + Arguments: + data_tensor (Tensor): contains sample data. + target_tensor (Tensor): contains sample targets (labels). + """ + + def __init__(self, data_tensor, target_tensor): + assert data_tensor.size(0) == target_tensor.size(0) + self.data_tensor = data_tensor + self.target_tensor = target_tensor + + def __getitem__(self, index): + return self.data_tensor[index], self.target_tensor[index] + + def __len__(self): + return self.data_tensor.size(0) + + +class ConcatDataset(Dataset): + """ + Dataset to concatenate multiple datasets. + Purpose: useful to assemble different existing datasets, possibly + large-scale datasets as the concatenation operation is done in an + on-the-fly manner. + + Arguments: + datasets (iterable): List of datasets to be concatenated + """ + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, datasets): + super(ConcatDataset, self).__init__() + assert len(datasets) > 0, 'datasets should not be an empty iterable' + self.datasets = list(datasets) + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + @property + def cummulative_sizes(self): + warnings.warn("cummulative_sizes attribute is renamed to " + "cumulative_sizes", DeprecationWarning, stacklevel=2) + return self.cumulative_sizes + + +class Subset(Dataset): + def __init__(self, dataset, indices): + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + + +def random_split(dataset, lengths): + """ + Randomly split a dataset into non-overlapping new datasets of given lengths + ds + + Arguments: + dataset (Dataset): Dataset to be split + lengths (iterable): lengths of splits to be produced + """ + if sum(lengths) != len(dataset): + raise ValueError("Sum of input lengths does not equal the length of the input dataset!") + + indices = randperm(sum(lengths)) + return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] diff --git a/models/ade20k/segm_lib/utils/data/distributed.py b/models/ade20k/segm_lib/utils/data/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d890e28fd2b9e044bdd9494de4a43ad2471eed --- /dev/null +++ b/models/ade20k/segm_lib/utils/data/distributed.py @@ -0,0 +1,58 @@ +import math +import torch +from .sampler import Sampler +from torch.distributed import get_world_size, get_rank + + +class DistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + num_replicas = get_world_size() + if rank is None: + rank = get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = list(torch.randperm(len(self.dataset), generator=g)) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset:offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/models/ade20k/segm_lib/utils/data/sampler.py b/models/ade20k/segm_lib/utils/data/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..62a9a43bd1d4c21fbdcb262db7da8d4fe27b26de --- /dev/null +++ b/models/ade20k/segm_lib/utils/data/sampler.py @@ -0,0 +1,131 @@ +import torch + + +class Sampler(object): + """Base class for all Samplers. + + Every Sampler subclass has to provide an __iter__ method, providing a way + to iterate over indices of dataset elements, and a __len__ method that + returns the length of the returned iterators. + """ + + def __init__(self, data_source): + pass + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class SequentialSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + +class RandomSampler(Sampler): + """Samples elements randomly, without replacement. + + Arguments: + data_source (Dataset): dataset to sample from + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(torch.randperm(len(self.data_source)).long()) + + def __len__(self): + return len(self.data_source) + + +class SubsetRandomSampler(Sampler): + """Samples elements randomly from a given list of indices, without replacement. + + Arguments: + indices (list): a list of indices + """ + + def __init__(self, indices): + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in torch.randperm(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class WeightedRandomSampler(Sampler): + """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). + + Arguments: + weights (list) : a list of weights, not necessary summing up to one + num_samples (int): number of samples to draw + replacement (bool): if ``True``, samples are drawn with replacement. + If not, they are drawn without replacement, which means that when a + sample index is drawn for a row, it cannot be drawn again for that row. + """ + + def __init__(self, weights, num_samples, replacement=True): + self.weights = torch.DoubleTensor(weights) + self.num_samples = num_samples + self.replacement = replacement + + def __iter__(self): + return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) + + def __len__(self): + return self.num_samples + + +class BatchSampler(object): + """Wraps another sampler to yield a mini-batch of indices. + + Args: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + + Example: + >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + """ + + def __init__(self, sampler, batch_size, drop_last): + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/models/ade20k/segm_lib/utils/th.py b/models/ade20k/segm_lib/utils/th.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6ef9385e3b5c0a439579d3fd7aa73b5dc62758 --- /dev/null +++ b/models/ade20k/segm_lib/utils/th.py @@ -0,0 +1,41 @@ +import torch +from torch.autograd import Variable +import numpy as np +import collections + +__all__ = ['as_variable', 'as_numpy', 'mark_volatile'] + +def as_variable(obj): + if isinstance(obj, Variable): + return obj + if isinstance(obj, collections.Sequence): + return [as_variable(v) for v in obj] + elif isinstance(obj, collections.Mapping): + return {k: as_variable(v) for k, v in obj.items()} + else: + return Variable(obj) + +def as_numpy(obj): + if isinstance(obj, collections.Sequence): + return [as_numpy(v) for v in obj] + elif isinstance(obj, collections.Mapping): + return {k: as_numpy(v) for k, v in obj.items()} + elif isinstance(obj, Variable): + return obj.data.cpu().numpy() + elif torch.is_tensor(obj): + return obj.cpu().numpy() + else: + return np.array(obj) + +def mark_volatile(obj): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + obj.no_grad = True + return obj + elif isinstance(obj, collections.Mapping): + return {k: mark_volatile(o) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [mark_volatile(o) for o in obj] + else: + return obj diff --git a/models/ade20k/utils.py b/models/ade20k/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f337db7db54c82be041698d694e1403e8918c4c0 --- /dev/null +++ b/models/ade20k/utils.py @@ -0,0 +1,40 @@ +"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" + +import os +import sys + +import numpy as np +import torch + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + + +def load_url(url, model_dir='./pretrained', map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split('/')[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) + + +def color_encode(labelmap, colors, mode='RGB'): + labelmap = labelmap.astype('int') + labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), + dtype=np.uint8) + for label in np.unique(labelmap): + if label < 0: + continue + labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ + np.tile(colors[label], + (labelmap.shape[0], labelmap.shape[1], 1)) + + if mode == 'BGR': + return labelmap_rgb[:, :, ::-1] + else: + return labelmap_rgb diff --git a/models/lpips_models/alex.pth b/models/lpips_models/alex.pth new file mode 100644 index 0000000000000000000000000000000000000000..fa4067abc5d4da16a7204fd94776506e4868030e --- /dev/null +++ b/models/lpips_models/alex.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0 +size 6009 diff --git a/models/lpips_models/squeeze.pth b/models/lpips_models/squeeze.pth new file mode 100644 index 0000000000000000000000000000000000000000..f892a84a130828b1c9e2e8156e84fc5a962c665d --- /dev/null +++ b/models/lpips_models/squeeze.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76 +size 10811 diff --git a/models/lpips_models/vgg.pth b/models/lpips_models/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/models/lpips_models/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..d7dea64b874cae0cdafdc65a43f6ae71ea47ca0c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.pyright] +venvPath = "." +venv = "env" +exclude = "env" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..306b540e0bf1eab49e9a661854e536441bad3788 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,29 @@ +aioredis==1.3.1 +boto3==1.24.61 +triton==2.0.0 +diffusers==0.17.1 +fastapi==0.87.0 +Pillow==9.3.0 +redis==4.3.4 +requests==2.28.1 +transformers +rembg==2.0.30 +accelerate==0.17.0 +gfpgan==1.3.8 +rembg==2.0.30 +controlnet-aux==0.0.5 +realesrgan==0.3.0 +compel==1.0.4 +scikit-learn==0.24.2 +easydict==1.9.0 +albumentations==0.5.2 +kornia==0.5.0 +pytorch-lightning==1.2.9 +pydash +pandas +xformers +torchvision +scikit-image +omegaconf +webdataset +git+https://github.com/cloneofsimo/lora.git diff --git a/saicinpainting/__init__.py b/saicinpainting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/evaluation/__init__.py b/saicinpainting/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c8117565b252ca069a808b31b8c52aaddd2289 --- /dev/null +++ b/saicinpainting/evaluation/__init__.py @@ -0,0 +1,33 @@ +import logging + +import torch + +from saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1 +from saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore + + +def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs): + logging.info(f'Make evaluator {kind}') + device = "cuda" if torch.cuda.is_available() else "cpu" + metrics = {} + if ssim: + metrics['ssim'] = SSIMScore() + if lpips: + metrics['lpips'] = LPIPSScore() + if fid: + metrics['fid'] = FIDScore().to(device) + + if integral_kind is None: + integral_func = None + elif integral_kind == 'ssim_fid100_f1': + integral_func = ssim_fid100_f1 + elif integral_kind == 'lpips_fid100_f1': + integral_func = lpips_fid100_f1 + else: + raise ValueError(f'Unexpected integral_kind={integral_kind}') + + if kind == 'default': + return InpaintingEvaluatorOnline(scores=metrics, + integral_func=integral_func, + integral_title=integral_kind, + **kwargs) diff --git a/saicinpainting/evaluation/data.py b/saicinpainting/evaluation/data.py new file mode 100644 index 0000000000000000000000000000000000000000..69ddb8d3c12d0261e459f7c4f66a702d0c477df0 --- /dev/null +++ b/saicinpainting/evaluation/data.py @@ -0,0 +1,167 @@ +import glob +import os + +import cv2 +import PIL.Image as Image +import numpy as np + +from torch.utils.data import Dataset +import torch.nn.functional as F + + +def load_image(fname, mode='RGB', return_orig=False): + img = np.array(Image.open(fname).convert(mode)) + if img.ndim == 3: + img = np.transpose(img, (2, 0, 1)) + out_img = img.astype('float32') / 255 + if return_orig: + return out_img, img + else: + return out_img + + +def ceil_modulo(x, mod): + if x % mod == 0: + return x + return (x // mod + 1) * mod + + +def pad_img_to_modulo(img, mod): + channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return np.pad(img, ((0, 0), (0, out_height - height), (0, out_width - width)), mode='symmetric') + + +def pad_tensor_to_modulo(img, mod): + batch_size, channels, height, width = img.shape + out_height = ceil_modulo(height, mod) + out_width = ceil_modulo(width, mod) + return F.pad(img, pad=(0, out_width - width, 0, out_height - height), mode='reflect') + + +def scale_image(img, factor, interpolation=cv2.INTER_AREA): + if img.shape[0] == 1: + img = img[0] + else: + img = np.transpose(img, (1, 2, 0)) + + img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation) + + if img.ndim == 2: + img = img[None, ...] + else: + img = np.transpose(img, (2, 0, 1)) + return img + + +class InpaintingDataset(Dataset): + def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None): + self.datadir = datadir + self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, '**', '*mask*.png'), recursive=True))) + self.img_filenames = [fname.rsplit('_mask', 1)[0] + img_suffix for fname in self.mask_filenames] + self.pad_out_to_modulo = pad_out_to_modulo + self.scale_factor = scale_factor + + def __len__(self): + return len(self.mask_filenames) + + def __getitem__(self, i): + image = load_image(self.img_filenames[i], mode='RGB') + mask = load_image(self.mask_filenames[i], mode='L') + result = dict(image=image, mask=mask[None, ...]) + + if self.scale_factor is not None: + result['image'] = scale_image(result['image'], self.scale_factor) + result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) + result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) + + return result + +class OurInpaintingDataset(Dataset): + def __init__(self, datadir, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None): + self.datadir = datadir + self.mask_filenames = sorted(list(glob.glob(os.path.join(self.datadir, 'mask', '**', '*mask*.png'), recursive=True))) + self.img_filenames = [os.path.join(self.datadir, 'img', os.path.basename(fname.rsplit('-', 1)[0].rsplit('_', 1)[0]) + '.png') for fname in self.mask_filenames] + self.pad_out_to_modulo = pad_out_to_modulo + self.scale_factor = scale_factor + + def __len__(self): + return len(self.mask_filenames) + + def __getitem__(self, i): + result = dict(image=load_image(self.img_filenames[i], mode='RGB'), + mask=load_image(self.mask_filenames[i], mode='L')[None, ...]) + + if self.scale_factor is not None: + result['image'] = scale_image(result['image'], self.scale_factor) + result['mask'] = scale_image(result['mask'], self.scale_factor) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) + result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) + + return result + +class PrecomputedInpaintingResultsDataset(InpaintingDataset): + def __init__(self, datadir, predictdir, inpainted_suffix='_inpainted.jpg', **kwargs): + super().__init__(datadir, **kwargs) + if not datadir.endswith('/'): + datadir += '/' + self.predictdir = predictdir + self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix) + for fname in self.mask_filenames] + + def __getitem__(self, i): + result = super().__getitem__(i) + result['inpainted'] = load_image(self.pred_filenames[i]) + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo) + return result + +class OurPrecomputedInpaintingResultsDataset(OurInpaintingDataset): + def __init__(self, datadir, predictdir, inpainted_suffix="png", **kwargs): + super().__init__(datadir, **kwargs) + if not datadir.endswith('/'): + datadir += '/' + self.predictdir = predictdir + self.pred_filenames = [os.path.join(predictdir, os.path.basename(os.path.splitext(fname)[0]) + f'_inpainted.{inpainted_suffix}') + for fname in self.mask_filenames] + # self.pred_filenames = [os.path.join(predictdir, os.path.splitext(fname[len(datadir):])[0] + inpainted_suffix) + # for fname in self.mask_filenames] + + def __getitem__(self, i): + result = super().__getitem__(i) + result['inpainted'] = self.file_loader(self.pred_filenames[i]) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['inpainted'] = pad_img_to_modulo(result['inpainted'], self.pad_out_to_modulo) + return result + +class InpaintingEvalOnlineDataset(Dataset): + def __init__(self, indir, mask_generator, img_suffix='.jpg', pad_out_to_modulo=None, scale_factor=None, **kwargs): + self.indir = indir + self.mask_generator = mask_generator + self.img_filenames = sorted(list(glob.glob(os.path.join(self.indir, '**', f'*{img_suffix}' ), recursive=True))) + self.pad_out_to_modulo = pad_out_to_modulo + self.scale_factor = scale_factor + + def __len__(self): + return len(self.img_filenames) + + def __getitem__(self, i): + img, raw_image = load_image(self.img_filenames[i], mode='RGB', return_orig=True) + mask = self.mask_generator(img, raw_image=raw_image) + result = dict(image=img, mask=mask) + + if self.scale_factor is not None: + result['image'] = scale_image(result['image'], self.scale_factor) + result['mask'] = scale_image(result['mask'], self.scale_factor, interpolation=cv2.INTER_NEAREST) + + if self.pad_out_to_modulo is not None and self.pad_out_to_modulo > 1: + result['image'] = pad_img_to_modulo(result['image'], self.pad_out_to_modulo) + result['mask'] = pad_img_to_modulo(result['mask'], self.pad_out_to_modulo) + return result \ No newline at end of file diff --git a/saicinpainting/evaluation/evaluator.py b/saicinpainting/evaluation/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9e80402633c08a580929b38a5cb695cb7171d8 --- /dev/null +++ b/saicinpainting/evaluation/evaluator.py @@ -0,0 +1,220 @@ +import logging +import math +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import tqdm +from torch.utils.data import DataLoader + +from saicinpainting.evaluation.utils import move_to_device + +LOGGER = logging.getLogger(__name__) + + +class InpaintingEvaluator(): + def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda', + integral_func=None, integral_title=None, clamp_image_range=None): + """ + :param dataset: torch.utils.data.Dataset which contains images and masks + :param scores: dict {score_name: EvaluatorScore object} + :param area_grouping: in addition to the overall scores, allows to compute score for the groups of samples + which are defined by share of area occluded by mask + :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1) + :param batch_size: batch_size for the dataloader + :param device: device to use + """ + self.scores = scores + self.dataset = dataset + + self.area_grouping = area_grouping + self.bins = bins + + self.device = torch.device(device) + + self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size) + + self.integral_func = integral_func + self.integral_title = integral_title + self.clamp_image_range = clamp_image_range + + def _get_bin_edges(self): + bin_edges = np.linspace(0, 1, self.bins + 1) + + num_digits = max(0, math.ceil(math.log10(self.bins)) - 1) + interval_names = [] + for idx_bin in range(self.bins): + start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \ + round(100 * bin_edges[idx_bin + 1], num_digits) + start_percent = '{:.{n}f}'.format(start_percent, n=num_digits) + end_percent = '{:.{n}f}'.format(end_percent, n=num_digits) + interval_names.append("{0}-{1}%".format(start_percent, end_percent)) + + groups = [] + for batch in self.dataloader: + mask = batch['mask'] + batch_size = mask.shape[0] + area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1) + bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1 + # corner case: when area is equal to 1, bin_indices should return bins - 1, not bins for that element + bin_indices[bin_indices == self.bins] = self.bins - 1 + groups.append(bin_indices) + groups = np.hstack(groups) + + return groups, interval_names + + def evaluate(self, model=None): + """ + :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch + :return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or + name of the particular group arranged by area of mask (e.g. '10-20%') + and score statistics for the group as values. + """ + results = dict() + if self.area_grouping: + groups, interval_names = self._get_bin_edges() + else: + groups = None + + for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'): + score.to(self.device) + with torch.no_grad(): + score.reset() + for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False): + batch = move_to_device(batch, self.device) + image_batch, mask_batch = batch['image'], batch['mask'] + if self.clamp_image_range is not None: + image_batch = torch.clamp(image_batch, + min=self.clamp_image_range[0], + max=self.clamp_image_range[1]) + if model is None: + assert 'inpainted' in batch, \ + 'Model is None, so we expected precomputed inpainting results at key "inpainted"' + inpainted_batch = batch['inpainted'] + else: + inpainted_batch = model(image_batch, mask_batch) + score(inpainted_batch, image_batch, mask_batch) + total_results, group_results = score.get_value(groups=groups) + + results[(score_name, 'total')] = total_results + if groups is not None: + for group_index, group_values in group_results.items(): + group_name = interval_names[group_index] + results[(score_name, group_name)] = group_values + + if self.integral_func is not None: + results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results)) + + return results + + +def ssim_fid100_f1(metrics, fid_scale=100): + ssim = metrics[('ssim', 'total')]['mean'] + fid = metrics[('fid', 'total')]['mean'] + fid_rel = max(0, fid_scale - fid) / fid_scale + f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3) + return f1 + + +def lpips_fid100_f1(metrics, fid_scale=100): + neg_lpips = 1 - metrics[('lpips', 'total')]['mean'] # invert, so bigger is better + fid = metrics[('fid', 'total')]['mean'] + fid_rel = max(0, fid_scale - fid) / fid_scale + f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3) + return f1 + + + +class InpaintingEvaluatorOnline(nn.Module): + def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted', + integral_func=None, integral_title=None, clamp_image_range=None): + """ + :param scores: dict {score_name: EvaluatorScore object} + :param bins: number of groups, partition is generated by np.linspace(0., 1., bins + 1) + :param device: device to use + """ + super().__init__() + LOGGER.info(f'{type(self)} init called') + self.scores = nn.ModuleDict(scores) + self.image_key = image_key + self.inpainted_key = inpainted_key + self.bins_num = bins + self.bin_edges = np.linspace(0, 1, self.bins_num + 1) + + num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1) + self.interval_names = [] + for idx_bin in range(self.bins_num): + start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \ + round(100 * self.bin_edges[idx_bin + 1], num_digits) + start_percent = '{:.{n}f}'.format(start_percent, n=num_digits) + end_percent = '{:.{n}f}'.format(end_percent, n=num_digits) + self.interval_names.append("{0}-{1}%".format(start_percent, end_percent)) + + self.groups = [] + + self.integral_func = integral_func + self.integral_title = integral_title + self.clamp_image_range = clamp_image_range + + LOGGER.info(f'{type(self)} init done') + + def _get_bins(self, mask_batch): + batch_size = mask_batch.shape[0] + area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy() + bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1) + return bin_indices + + def forward(self, batch: Dict[str, torch.Tensor]): + """ + Calculate and accumulate metrics for batch. To finalize evaluation and obtain final metrics, call evaluation_end + :param batch: batch dict with mandatory fields mask, image, inpainted (can be overriden by self.inpainted_key) + """ + result = {} + with torch.no_grad(): + image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key] + if self.clamp_image_range is not None: + image_batch = torch.clamp(image_batch, + min=self.clamp_image_range[0], + max=self.clamp_image_range[1]) + self.groups.extend(self._get_bins(mask_batch)) + + for score_name, score in self.scores.items(): + result[score_name] = score(inpainted_batch, image_batch, mask_batch) + return result + + def process_batch(self, batch: Dict[str, torch.Tensor]): + return self(batch) + + def evaluation_end(self, states=None): + """:return: dict with (score_name, group_type) as keys, where group_type can be either 'overall' or + name of the particular group arranged by area of mask (e.g. '10-20%') + and score statistics for the group as values. + """ + LOGGER.info(f'{type(self)}: evaluation_end called') + + self.groups = np.array(self.groups) + + results = {} + for score_name, score in self.scores.items(): + LOGGER.info(f'Getting value of {score_name}') + cur_states = [s[score_name] for s in states] if states is not None else None + total_results, group_results = score.get_value(groups=self.groups, states=cur_states) + LOGGER.info(f'Getting value of {score_name} done') + results[(score_name, 'total')] = total_results + + for group_index, group_values in group_results.items(): + group_name = self.interval_names[group_index] + results[(score_name, group_name)] = group_values + + if self.integral_func is not None: + results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results)) + + LOGGER.info(f'{type(self)}: reset scores') + self.groups = [] + for sc in self.scores.values(): + sc.reset() + LOGGER.info(f'{type(self)}: reset scores done') + + LOGGER.info(f'{type(self)}: evaluation_end done') + return results diff --git a/saicinpainting/evaluation/losses/__init__.py b/saicinpainting/evaluation/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/evaluation/losses/base_loss.py b/saicinpainting/evaluation/losses/base_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..391191ce2ed8665f1f15bd3877dc22bb85b147d6 --- /dev/null +++ b/saicinpainting/evaluation/losses/base_loss.py @@ -0,0 +1,528 @@ +import logging +from abc import abstractmethod, ABC + +import numpy as np +import sklearn +import sklearn.svm +import torch +import torch.nn as nn +import torch.nn.functional as F +from joblib import Parallel, delayed +from scipy import linalg + +from models.ade20k import SegmentationModule, NUM_CLASS, segm_options +from .fid.inception import InceptionV3 +from .lpips import PerceptualLoss +from .ssim import SSIM + +LOGGER = logging.getLogger(__name__) + + +def get_groupings(groups): + """ + :param groups: group numbers for respective elements + :return: dict of kind {group_idx: indices of the corresponding group elements} + """ + label_groups, count_groups = np.unique(groups, return_counts=True) + + indices = np.argsort(groups) + + grouping = dict() + cur_start = 0 + for label, count in zip(label_groups, count_groups): + cur_end = cur_start + count + cur_indices = indices[cur_start:cur_end] + grouping[label] = cur_indices + cur_start = cur_end + return grouping + + +class EvaluatorScore(nn.Module): + @abstractmethod + def forward(self, pred_batch, target_batch, mask): + pass + + @abstractmethod + def get_value(self, groups=None, states=None): + pass + + @abstractmethod + def reset(self): + pass + + +class PairwiseScore(EvaluatorScore, ABC): + def __init__(self): + super().__init__() + self.individual_values = None + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + individual_values = torch.stack(states, dim=0).reshape(-1).cpu().numpy() if states is not None \ + else self.individual_values + + total_results = { + 'mean': individual_values.mean(), + 'std': individual_values.std() + } + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + group_scores = individual_values[index] + group_results[label] = { + 'mean': group_scores.mean(), + 'std': group_scores.std() + } + return total_results, group_results + + def reset(self): + self.individual_values = [] + + +class SSIMScore(PairwiseScore): + def __init__(self, window_size=11): + super().__init__() + self.score = SSIM(window_size=window_size, size_average=False).eval() + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + batch_values = self.score(pred_batch, target_batch) + self.individual_values = np.hstack([ + self.individual_values, batch_values.detach().cpu().numpy() + ]) + return batch_values + + +class LPIPSScore(PairwiseScore): + def __init__(self, model='net-lin', net='vgg', model_path=None, use_gpu=True): + super().__init__() + self.score = PerceptualLoss(model=model, net=net, model_path=model_path, + use_gpu=use_gpu, spatial=False).eval() + self.reset() + + def forward(self, pred_batch, target_batch, mask=None): + batch_values = self.score(pred_batch, target_batch).flatten() + self.individual_values = np.hstack([ + self.individual_values, batch_values.detach().cpu().numpy() + ]) + return batch_values + + +def fid_calculate_activation_statistics(act): + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(activations_pred, activations_target, eps=1e-6): + mu1, sigma1 = fid_calculate_activation_statistics(activations_pred) + mu2, sigma2 = fid_calculate_activation_statistics(activations_target) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + LOGGER.warning(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +class FIDScore(EvaluatorScore): + def __init__(self, dims=2048, eps=1e-6): + LOGGER.info("FIDscore init called") + super().__init__() + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.reset() + LOGGER.info("FIDscore init done") + + def forward(self, pred_batch, target_batch, mask=None): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + + self.activations_pred.append(activations_pred.detach().cpu()) + self.activations_target.append(activations_target.detach().cpu()) + + return activations_pred, activations_target + + def get_value(self, groups=None, states=None): + LOGGER.info("FIDscore get_value called") + activations_pred, activations_target = zip(*states) if states is not None \ + else (self.activations_pred, self.activations_target) + activations_pred = torch.cat(activations_pred).cpu().numpy() + activations_target = torch.cat(activations_target).cpu().numpy() + + total_distance = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps) + total_results = dict(mean=total_distance) + + if groups is None: + group_results = None + else: + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + if len(index) > 1: + group_distance = calculate_frechet_distance(activations_pred[index], activations_target[index], + eps=self.eps) + group_results[label] = dict(mean=group_distance) + + else: + group_results[label] = dict(mean=float('nan')) + + self.reset() + + LOGGER.info("FIDscore get_value done") + + return total_results, group_results + + def reset(self): + self.activations_pred = [] + self.activations_target = [] + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + assert False, \ + 'We should not have got here, because Inception always scales inputs to 299x299' + # activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1)) + activations = activations.squeeze(-1).squeeze(-1) + return activations + + +class SegmentationAwareScore(EvaluatorScore): + def __init__(self, weights_path): + super().__init__() + self.segm_network = SegmentationModule(weights_path=weights_path, use_default_normalization=True).eval() + self.target_class_freq_by_image_total = [] + self.target_class_freq_by_image_mask = [] + self.pred_class_freq_by_image_mask = [] + + def forward(self, pred_batch, target_batch, mask): + pred_segm_flat = self.segm_network.predict(pred_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy() + target_segm_flat = self.segm_network.predict(target_batch)[0].view(pred_batch.shape[0], -1).long().detach().cpu().numpy() + mask_flat = (mask.view(mask.shape[0], -1) > 0.5).detach().cpu().numpy() + + batch_target_class_freq_total = [] + batch_target_class_freq_mask = [] + batch_pred_class_freq_mask = [] + + for cur_pred_segm, cur_target_segm, cur_mask in zip(pred_segm_flat, target_segm_flat, mask_flat): + cur_target_class_freq_total = np.bincount(cur_target_segm, minlength=NUM_CLASS)[None, ...] + cur_target_class_freq_mask = np.bincount(cur_target_segm[cur_mask], minlength=NUM_CLASS)[None, ...] + cur_pred_class_freq_mask = np.bincount(cur_pred_segm[cur_mask], minlength=NUM_CLASS)[None, ...] + + self.target_class_freq_by_image_total.append(cur_target_class_freq_total) + self.target_class_freq_by_image_mask.append(cur_target_class_freq_mask) + self.pred_class_freq_by_image_mask.append(cur_pred_class_freq_mask) + + batch_target_class_freq_total.append(cur_target_class_freq_total) + batch_target_class_freq_mask.append(cur_target_class_freq_mask) + batch_pred_class_freq_mask.append(cur_pred_class_freq_mask) + + batch_target_class_freq_total = np.concatenate(batch_target_class_freq_total, axis=0) + batch_target_class_freq_mask = np.concatenate(batch_target_class_freq_mask, axis=0) + batch_pred_class_freq_mask = np.concatenate(batch_pred_class_freq_mask, axis=0) + return batch_target_class_freq_total, batch_target_class_freq_mask, batch_pred_class_freq_mask + + def reset(self): + super().reset() + self.target_class_freq_by_image_total = [] + self.target_class_freq_by_image_mask = [] + self.pred_class_freq_by_image_mask = [] + + +def distribute_values_to_classes(target_class_freq_by_image_mask, values, idx2name): + assert target_class_freq_by_image_mask.ndim == 2 and target_class_freq_by_image_mask.shape[0] == values.shape[0] + total_class_freq = target_class_freq_by_image_mask.sum(0) + distr_values = (target_class_freq_by_image_mask * values[..., None]).sum(0) + result = distr_values / (total_class_freq + 1e-3) + return {idx2name[i]: val for i, val in enumerate(result) if total_class_freq[i] > 0} + + +def get_segmentation_idx2name(): + return {i - 1: name for i, name in segm_options['classes'].set_index('Idx', drop=True)['Name'].to_dict().items()} + + +class SegmentationAwarePairwiseScore(SegmentationAwareScore): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.individual_values = [] + self.segm_idx2name = get_segmentation_idx2name() + + def forward(self, pred_batch, target_batch, mask): + cur_class_stats = super().forward(pred_batch, target_batch, mask) + score_values = self.calc_score(pred_batch, target_batch, mask) + self.individual_values.append(score_values) + return cur_class_stats + (score_values,) + + @abstractmethod + def calc_score(self, pred_batch, target_batch, mask): + raise NotImplementedError() + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + if states is not None: + (target_class_freq_by_image_total, + target_class_freq_by_image_mask, + pred_class_freq_by_image_mask, + individual_values) = states + else: + target_class_freq_by_image_total = self.target_class_freq_by_image_total + target_class_freq_by_image_mask = self.target_class_freq_by_image_mask + pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask + individual_values = self.individual_values + + target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) + target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) + pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) + individual_values = np.concatenate(individual_values, axis=0) + + total_results = { + 'mean': individual_values.mean(), + 'std': individual_values.std(), + **distribute_values_to_classes(target_class_freq_by_image_mask, individual_values, self.segm_idx2name) + } + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + group_class_freq = target_class_freq_by_image_mask[index] + group_scores = individual_values[index] + group_results[label] = { + 'mean': group_scores.mean(), + 'std': group_scores.std(), + ** distribute_values_to_classes(group_class_freq, group_scores, self.segm_idx2name) + } + return total_results, group_results + + def reset(self): + super().reset() + self.individual_values = [] + + +class SegmentationClassStats(SegmentationAwarePairwiseScore): + def calc_score(self, pred_batch, target_batch, mask): + return 0 + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + if states is not None: + (target_class_freq_by_image_total, + target_class_freq_by_image_mask, + pred_class_freq_by_image_mask, + _) = states + else: + target_class_freq_by_image_total = self.target_class_freq_by_image_total + target_class_freq_by_image_mask = self.target_class_freq_by_image_mask + pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask + + target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) + target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) + pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) + + target_class_freq_by_image_total_marginal = target_class_freq_by_image_total.sum(0).astype('float32') + target_class_freq_by_image_total_marginal /= target_class_freq_by_image_total_marginal.sum() + + target_class_freq_by_image_mask_marginal = target_class_freq_by_image_mask.sum(0).astype('float32') + target_class_freq_by_image_mask_marginal /= target_class_freq_by_image_mask_marginal.sum() + + pred_class_freq_diff = (pred_class_freq_by_image_mask - target_class_freq_by_image_mask).sum(0) / (target_class_freq_by_image_mask.sum(0) + 1e-3) + + total_results = dict() + total_results.update({f'total_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(target_class_freq_by_image_total_marginal) + if v > 0}) + total_results.update({f'mask_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(target_class_freq_by_image_mask_marginal) + if v > 0}) + total_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v + for i, v in enumerate(pred_class_freq_diff) + if target_class_freq_by_image_total_marginal[i] > 0}) + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + group_target_class_freq_by_image_total = target_class_freq_by_image_total[index] + group_target_class_freq_by_image_mask = target_class_freq_by_image_mask[index] + group_pred_class_freq_by_image_mask = pred_class_freq_by_image_mask[index] + + group_target_class_freq_by_image_total_marginal = group_target_class_freq_by_image_total.sum(0).astype('float32') + group_target_class_freq_by_image_total_marginal /= group_target_class_freq_by_image_total_marginal.sum() + + group_target_class_freq_by_image_mask_marginal = group_target_class_freq_by_image_mask.sum(0).astype('float32') + group_target_class_freq_by_image_mask_marginal /= group_target_class_freq_by_image_mask_marginal.sum() + + group_pred_class_freq_diff = (group_pred_class_freq_by_image_mask - group_target_class_freq_by_image_mask).sum(0) / ( + group_target_class_freq_by_image_mask.sum(0) + 1e-3) + + cur_group_results = dict() + cur_group_results.update({f'total_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(group_target_class_freq_by_image_total_marginal) + if v > 0}) + cur_group_results.update({f'mask_freq/{self.segm_idx2name[i]}': v + for i, v in enumerate(group_target_class_freq_by_image_mask_marginal) + if v > 0}) + cur_group_results.update({f'mask_freq_diff/{self.segm_idx2name[i]}': v + for i, v in enumerate(group_pred_class_freq_diff) + if group_target_class_freq_by_image_total_marginal[i] > 0}) + + group_results[label] = cur_group_results + return total_results, group_results + + +class SegmentationAwareSSIM(SegmentationAwarePairwiseScore): + def __init__(self, *args, window_size=11, **kwargs): + super().__init__(*args, **kwargs) + self.score_impl = SSIM(window_size=window_size, size_average=False).eval() + + def calc_score(self, pred_batch, target_batch, mask): + return self.score_impl(pred_batch, target_batch).detach().cpu().numpy() + + +class SegmentationAwareLPIPS(SegmentationAwarePairwiseScore): + def __init__(self, *args, model='net-lin', net='vgg', model_path=None, use_gpu=True, **kwargs): + super().__init__(*args, **kwargs) + self.score_impl = PerceptualLoss(model=model, net=net, model_path=model_path, + use_gpu=use_gpu, spatial=False).eval() + + def calc_score(self, pred_batch, target_batch, mask): + return self.score_impl(pred_batch, target_batch).flatten().detach().cpu().numpy() + + +def calculade_fid_no_img(img_i, activations_pred, activations_target, eps=1e-6): + activations_pred = activations_pred.copy() + activations_pred[img_i] = activations_target[img_i] + return calculate_frechet_distance(activations_pred, activations_target, eps=eps) + + +class SegmentationAwareFID(SegmentationAwarePairwiseScore): + def __init__(self, *args, dims=2048, eps=1e-6, n_jobs=-1, **kwargs): + super().__init__(*args, **kwargs) + if getattr(FIDScore, '_MODEL', None) is None: + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + FIDScore._MODEL = InceptionV3([block_idx]).eval() + self.model = FIDScore._MODEL + self.eps = eps + self.n_jobs = n_jobs + + def calc_score(self, pred_batch, target_batch, mask): + activations_pred = self._get_activations(pred_batch) + activations_target = self._get_activations(target_batch) + return activations_pred, activations_target + + def get_value(self, groups=None, states=None): + """ + :param groups: + :return: + total_results: dict of kind {'mean': score mean, 'std': score std} + group_results: None, if groups is None; + else dict {group_idx: {'mean': score mean among group, 'std': score std among group}} + """ + if states is not None: + (target_class_freq_by_image_total, + target_class_freq_by_image_mask, + pred_class_freq_by_image_mask, + activation_pairs) = states + else: + target_class_freq_by_image_total = self.target_class_freq_by_image_total + target_class_freq_by_image_mask = self.target_class_freq_by_image_mask + pred_class_freq_by_image_mask = self.pred_class_freq_by_image_mask + activation_pairs = self.individual_values + + target_class_freq_by_image_total = np.concatenate(target_class_freq_by_image_total, axis=0) + target_class_freq_by_image_mask = np.concatenate(target_class_freq_by_image_mask, axis=0) + pred_class_freq_by_image_mask = np.concatenate(pred_class_freq_by_image_mask, axis=0) + activations_pred, activations_target = zip(*activation_pairs) + activations_pred = np.concatenate(activations_pred, axis=0) + activations_target = np.concatenate(activations_target, axis=0) + + total_results = { + 'mean': calculate_frechet_distance(activations_pred, activations_target, eps=self.eps), + 'std': 0, + **self.distribute_fid_to_classes(target_class_freq_by_image_mask, activations_pred, activations_target) + } + + if groups is None: + return total_results, None + + group_results = dict() + grouping = get_groupings(groups) + for label, index in grouping.items(): + if len(index) > 1: + group_activations_pred = activations_pred[index] + group_activations_target = activations_target[index] + group_class_freq = target_class_freq_by_image_mask[index] + group_results[label] = { + 'mean': calculate_frechet_distance(group_activations_pred, group_activations_target, eps=self.eps), + 'std': 0, + **self.distribute_fid_to_classes(group_class_freq, + group_activations_pred, + group_activations_target) + } + else: + group_results[label] = dict(mean=float('nan'), std=0) + return total_results, group_results + + def distribute_fid_to_classes(self, class_freq, activations_pred, activations_target): + real_fid = calculate_frechet_distance(activations_pred, activations_target, eps=self.eps) + + fid_no_images = Parallel(n_jobs=self.n_jobs)( + delayed(calculade_fid_no_img)(img_i, activations_pred, activations_target, eps=self.eps) + for img_i in range(activations_pred.shape[0]) + ) + errors = real_fid - fid_no_images + return distribute_values_to_classes(class_freq, errors, self.segm_idx2name) + + def _get_activations(self, batch): + activations = self.model(batch)[0] + if activations.shape[2] != 1 or activations.shape[3] != 1: + activations = F.adaptive_avg_pool2d(activations, output_size=(1, 1)) + activations = activations.squeeze(-1).squeeze(-1).detach().cpu().numpy() + return activations diff --git a/saicinpainting/evaluation/losses/fid/__init__.py b/saicinpainting/evaluation/losses/fid/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/evaluation/losses/fid/fid_score.py b/saicinpainting/evaluation/losses/fid/fid_score.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca8e602c21bb6a624d646da3f6479aea033b0ac --- /dev/null +++ b/saicinpainting/evaluation/losses/fid/fid_score.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +"""Calculates the Frechet Inception Distance (FID) to evalulate GANs + +The FID metric calculates the distance between two distributions of images. +Typically, we have summary statistics (mean & covariance matrix) of one +of these distributions, while the 2nd distribution is given by a GAN. + +When run as a stand-alone program, it compares the distribution of +images that are stored as PNG/JPEG at a specified location with a +distribution given by summary statistics (in pickle format). + +The FID is calculated by assuming that X_1 and X_2 are the activations of +the pool_3 layer of the inception net for generated samples and real world +samples respectively. + +See --help to see further details. + +Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead +of Tensorflow + +Copyright 2018 Institute of Bioinformatics, JKU Linz + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import pathlib +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +import numpy as np +import torch +# from scipy.misc import imread +from imageio import imread +from PIL import Image, JpegImagePlugin +from scipy import linalg +from torch.nn.functional import adaptive_avg_pool2d +from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor + +try: + from tqdm import tqdm +except ImportError: + # If not tqdm is not available, provide a mock version of it + def tqdm(x): return x + +try: + from .inception import InceptionV3 +except ModuleNotFoundError: + from inception import InceptionV3 + +parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) +parser.add_argument('path', type=str, nargs=2, + help=('Path to the generated images or ' + 'to .npz statistic files')) +parser.add_argument('--batch-size', type=int, default=50, + help='Batch size to use') +parser.add_argument('--dims', type=int, default=2048, + choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), + help=('Dimensionality of Inception features to use. ' + 'By default, uses pool3 features')) +parser.add_argument('-c', '--gpu', default='', type=str, + help='GPU to use (leave blank for CPU only)') +parser.add_argument('--resize', default=256) + +transform = Compose([Resize(256), CenterCrop(256), ToTensor()]) + + +def get_activations(files, model, batch_size=50, dims=2048, + cuda=False, verbose=False, keep_size=False): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- cuda : If set to True, use GPU + -- verbose : If set to True and parameter out_step is given, the number + of calculated batches is reported. + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if len(files) % batch_size != 0: + print(('Warning: number of images is not a multiple of the ' + 'batch size. Some samples are going to be ignored.')) + if batch_size > len(files): + print(('Warning: batch size is bigger than the data size. ' + 'Setting batch size to data size')) + batch_size = len(files) + + n_batches = len(files) // batch_size + n_used_imgs = n_batches * batch_size + + pred_arr = np.empty((n_used_imgs, dims)) + + for i in tqdm(range(n_batches)): + if verbose: + print('\rPropagating batch %d/%d' % (i + 1, n_batches), + end='', flush=True) + start = i * batch_size + end = start + batch_size + + # # Official code goes below + # images = np.array([imread(str(f)).astype(np.float32) + # for f in files[start:end]]) + + # # Reshape to (n_images, 3, height, width) + # images = images.transpose((0, 3, 1, 2)) + # images /= 255 + # batch = torch.from_numpy(images).type(torch.FloatTensor) + # # + + t = transform if not keep_size else ToTensor() + + if isinstance(files[0], pathlib.PosixPath): + images = [t(Image.open(str(f))) for f in files[start:end]] + + elif isinstance(files[0], Image.Image): + images = [t(f) for f in files[start:end]] + + else: + raise ValueError(f"Unknown data type for image: {type(files[0])}") + + batch = torch.stack(images) + + if cuda: + batch = batch.cuda() + + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.shape[2] != 1 or pred.shape[3] != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1) + + if verbose: + print(' done') + + return pred_arr + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-2): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def calculate_activation_statistics(files, model, batch_size=50, + dims=2048, cuda=False, verbose=False, keep_size=False): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- cuda : If set to True, use GPU + -- verbose : If set to True and parameter out_step is given, the + number of calculated batches is reported. + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. + """ + act = get_activations(files, model, batch_size, dims, cuda, verbose, keep_size=keep_size) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def _compute_statistics_of_path(path, model, batch_size, dims, cuda): + if path.endswith('.npz'): + f = np.load(path) + m, s = f['mu'][:], f['sigma'][:] + f.close() + else: + path = pathlib.Path(path) + files = list(path.glob('*.jpg')) + list(path.glob('*.png')) + m, s = calculate_activation_statistics(files, model, batch_size, + dims, cuda) + + return m, s + + +def _compute_statistics_of_images(images, model, batch_size, dims, cuda, keep_size=False): + if isinstance(images, list): # exact paths to files are provided + m, s = calculate_activation_statistics(images, model, batch_size, + dims, cuda, keep_size=keep_size) + + return m, s + + else: + raise ValueError + + +def calculate_fid_given_paths(paths, batch_size, cuda, dims): + """Calculates the FID of two paths""" + for p in paths: + if not os.path.exists(p): + raise RuntimeError('Invalid path: %s' % p) + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]) + if cuda: + model.cuda() + + m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, + dims, cuda) + m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, + dims, cuda) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value + + +def calculate_fid_given_images(images, batch_size, cuda, dims, use_globals=False, keep_size=False): + if use_globals: + global FID_MODEL # for multiprocessing + + for imgs in images: + if isinstance(imgs, list) and isinstance(imgs[0], (Image.Image, JpegImagePlugin.JpegImageFile)): + pass + else: + raise RuntimeError('Invalid images') + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + if 'FID_MODEL' not in globals() or not use_globals: + model = InceptionV3([block_idx]) + if cuda: + model.cuda() + + if use_globals: + FID_MODEL = model + + else: + model = FID_MODEL + + m1, s1 = _compute_statistics_of_images(images[0], model, batch_size, + dims, cuda, keep_size=False) + m2, s2 = _compute_statistics_of_images(images[1], model, batch_size, + dims, cuda, keep_size=False) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + return fid_value + + +if __name__ == '__main__': + args = parser.parse_args() + os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu + + fid_value = calculate_fid_given_paths(args.path, + args.batch_size, + args.gpu != '', + args.dims) + print('FID: ', fid_value) diff --git a/saicinpainting/evaluation/losses/fid/inception.py b/saicinpainting/evaluation/losses/fid/inception.py new file mode 100644 index 0000000000000000000000000000000000000000..e9bd0863b457aaa40c770eaa4acbb142b18fc18b --- /dev/null +++ b/saicinpainting/evaluation/losses/fid/inception.py @@ -0,0 +1,323 @@ +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +try: + from torchvision.models.utils import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + +# Inception weights ported to Pytorch from +# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz +FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' + + +LOGGER = logging.getLogger(__name__) + + +class InceptionV3(nn.Module): + """Pretrained InceptionV3 network returning feature maps""" + + # Index of default block of inception to return, + # corresponds to output of final average pooling + DEFAULT_BLOCK_INDEX = 3 + + # Maps feature dimensionality to their output blocks indices + BLOCK_INDEX_BY_DIM = { + 64: 0, # First max pooling features + 192: 1, # Second max pooling featurs + 768: 2, # Pre-aux classifier features + 2048: 3 # Final average pooling features + } + + def __init__(self, + output_blocks=[DEFAULT_BLOCK_INDEX], + resize_input=True, + normalize_input=True, + requires_grad=False, + use_fid_inception=True): + """Build pretrained InceptionV3 + + Parameters + ---------- + output_blocks : list of int + Indices of blocks to return features of. Possible values are: + - 0: corresponds to output of first max pooling + - 1: corresponds to output of second max pooling + - 2: corresponds to output which is fed to aux classifier + - 3: corresponds to output of final average pooling + resize_input : bool + If true, bilinearly resizes input to width and height 299 before + feeding input to model. As the network without fully connected + layers is fully convolutional, it should be able to handle inputs + of arbitrary size, so resizing might not be strictly needed + normalize_input : bool + If true, scales the input from range (0, 1) to the range the + pretrained Inception network expects, namely (-1, 1) + requires_grad : bool + If true, parameters of the model require gradients. Possibly useful + for finetuning the network + use_fid_inception : bool + If true, uses the pretrained Inception model used in Tensorflow's + FID implementation. If false, uses the pretrained Inception model + available in torchvision. The FID Inception model has different + weights and a slightly different structure from torchvision's + Inception model. If you want to compute FID scores, you are + strongly advised to set this parameter to true to get comparable + results. + """ + super(InceptionV3, self).__init__() + + self.resize_input = resize_input + self.normalize_input = normalize_input + self.output_blocks = sorted(output_blocks) + self.last_needed_block = max(output_blocks) + + assert self.last_needed_block <= 3, \ + 'Last possible output block index is 3' + + self.blocks = nn.ModuleList() + + if use_fid_inception: + inception = fid_inception_v3() + else: + inception = models.inception_v3(pretrained=True) + + # Block 0: input to maxpool1 + block0 = [ + inception.Conv2d_1a_3x3, + inception.Conv2d_2a_3x3, + inception.Conv2d_2b_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block0)) + + # Block 1: maxpool1 to maxpool2 + if self.last_needed_block >= 1: + block1 = [ + inception.Conv2d_3b_1x1, + inception.Conv2d_4a_3x3, + nn.MaxPool2d(kernel_size=3, stride=2) + ] + self.blocks.append(nn.Sequential(*block1)) + + # Block 2: maxpool2 to aux classifier + if self.last_needed_block >= 2: + block2 = [ + inception.Mixed_5b, + inception.Mixed_5c, + inception.Mixed_5d, + inception.Mixed_6a, + inception.Mixed_6b, + inception.Mixed_6c, + inception.Mixed_6d, + inception.Mixed_6e, + ] + self.blocks.append(nn.Sequential(*block2)) + + # Block 3: aux classifier to final avgpool + if self.last_needed_block >= 3: + block3 = [ + inception.Mixed_7a, + inception.Mixed_7b, + inception.Mixed_7c, + nn.AdaptiveAvgPool2d(output_size=(1, 1)) + ] + self.blocks.append(nn.Sequential(*block3)) + + for param in self.parameters(): + param.requires_grad = requires_grad + + def forward(self, inp): + """Get Inception feature maps + + Parameters + ---------- + inp : torch.autograd.Variable + Input tensor of shape Bx3xHxW. Values are expected to be in + range (0, 1) + + Returns + ------- + List of torch.autograd.Variable, corresponding to the selected output + block, sorted ascending by index + """ + outp = [] + x = inp + + if self.resize_input: + x = F.interpolate(x, + size=(299, 299), + mode='bilinear', + align_corners=False) + + if self.normalize_input: + x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) + + for idx, block in enumerate(self.blocks): + x = block(x) + if idx in self.output_blocks: + outp.append(x) + + if idx == self.last_needed_block: + break + + return outp + + +def fid_inception_v3(): + """Build pretrained Inception model for FID computation + + The Inception model for FID computation uses a different set of weights + and has a slightly different structure than torchvision's Inception. + + This method first constructs torchvision's Inception and then patches the + necessary parts that are different in the FID Inception model. + """ + LOGGER.info('fid_inception_v3 called') + inception = models.inception_v3(num_classes=1008, + aux_logits=False, + pretrained=False) + LOGGER.info('models.inception_v3 done') + inception.Mixed_5b = FIDInceptionA(192, pool_features=32) + inception.Mixed_5c = FIDInceptionA(256, pool_features=64) + inception.Mixed_5d = FIDInceptionA(288, pool_features=64) + inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) + inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) + inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) + inception.Mixed_7b = FIDInceptionE_1(1280) + inception.Mixed_7c = FIDInceptionE_2(2048) + + LOGGER.info('fid_inception_v3 patching done') + + state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) + LOGGER.info('fid_inception_v3 weights downloaded') + + inception.load_state_dict(state_dict) + LOGGER.info('fid_inception_v3 weights loaded into model') + + return inception + + +class FIDInceptionA(models.inception.InceptionA): + """InceptionA block patched for FID computation""" + def __init__(self, in_channels, pool_features): + super(FIDInceptionA, self).__init__(in_channels, pool_features) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionC(models.inception.InceptionC): + """InceptionC block patched for FID computation""" + def __init__(self, in_channels, channels_7x7): + super(FIDInceptionC, self).__init__(in_channels, channels_7x7) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_1(models.inception.InceptionE): + """First InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_1, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: Tensorflow's average pool does not use the padded zero's in + # its average calculation + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, + count_include_pad=False) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) + + +class FIDInceptionE_2(models.inception.InceptionE): + """Second InceptionE block patched for FID computation""" + def __init__(self, in_channels): + super(FIDInceptionE_2, self).__init__(in_channels) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + # Patch: The FID Inception model uses max pooling instead of average + # pooling. This is likely an error in this specific Inception + # implementation, as other Inception models use average pooling here + # (which matches the description in the paper). + branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return torch.cat(outputs, 1) diff --git a/saicinpainting/evaluation/losses/lpips.py b/saicinpainting/evaluation/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f19b747f2457902695213f7efcde4fdc306c1f --- /dev/null +++ b/saicinpainting/evaluation/losses/lpips.py @@ -0,0 +1,891 @@ +############################################################ +# The contents below have been combined using files in the # +# following repository: # +# https://github.com/richzhang/PerceptualSimilarity # +############################################################ + +############################################################ +# __init__.py # +############################################################ + +import numpy as np +from skimage.metrics import structural_similarity +import torch + +from saicinpainting.utils import get_shape + + +class PerceptualLoss(torch.nn.Module): + def __init__(self, model='net-lin', net='alex', colorspace='rgb', model_path=None, spatial=False, use_gpu=True): + # VGG using our perceptually-learned weights (LPIPS metric) + # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss + super(PerceptualLoss, self).__init__() + self.use_gpu = use_gpu + self.spatial = spatial + self.model = DistModel() + self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, + model_path=model_path, spatial=self.spatial) + + def forward(self, pred, target, normalize=True): + """ + Pred and target are Variables. + If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] + If normalize is False, assumes the images are already between [-1,+1] + Inputs pred and target are Nx3xHxW + Output pytorch Variable N long + """ + + if normalize: + target = 2 * target - 1 + pred = 2 * pred - 1 + + return self.model(target, pred) + + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +def l2(p0, p1, range=255.): + return .5 * np.mean((p0 / range - p1 / range) ** 2) + + +def psnr(p0, p1, peak=255.): + return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) + + +def dssim(p0, p1, range=255.): + return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. + + +def rgb2lab(in_img, mean_cent=False): + from skimage import color + img_lab = color.rgb2lab(in_img) + if (mean_cent): + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + return img_lab + + +def tensor2np(tensor_obj): + # change dimension of a tensor object into a numpy array + return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) + + +def np2tensor(np_obj): + # change dimenion of np array into tensor array + return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): + # image tensor to lab tensor + from skimage import color + + img = tensor2im(image_tensor) + img_lab = color.rgb2lab(img) + if (mc_only): + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + if (to_norm and not mc_only): + img_lab[:, :, 0] = img_lab[:, :, 0] - 50 + img_lab = img_lab / 100. + + return np2tensor(img_lab) + + +def tensorlab2tensor(lab_tensor, return_inbnd=False): + from skimage import color + import warnings + warnings.filterwarnings("ignore") + + lab = tensor2np(lab_tensor) * 100. + lab[:, :, 0] = lab[:, :, 0] + 50 + + rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) + if (return_inbnd): + # convert back to lab, see if we match + lab_back = color.rgb2lab(rgb_back.astype('uint8')) + mask = 1. * np.isclose(lab_back, lab, atol=2.) + mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) + return (im2tensor(rgb_back), mask) + else: + return im2tensor(rgb_back) + + +def rgb2lab(input): + from skimage import color + return color.rgb2lab(input / 255.) + + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +def tensor2vec(vector_tensor): + return vector_tensor.data.cpu().numpy()[:, :, 0, 0] + + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): + # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): + # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) + + +############################################################ +# base_model.py # +############################################################ + + +class BaseModel(torch.nn.Module): + def __init__(self): + super().__init__() + + def name(self): + return 'BaseModel' + + def initialize(self, use_gpu=True): + self.use_gpu = use_gpu + + def forward(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, path, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(path, save_filename) + torch.save(network.state_dict(), save_path) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print('Loading network from %s' % save_path) + network.load_state_dict(torch.load(save_path, map_location='cpu')) + + def update_learning_rate(): + pass + + def get_image_paths(self): + return self.image_paths + + def save_done(self, flag=False): + np.save(os.path.join(self.save_dir, 'done_flag'), flag) + np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i') + + +############################################################ +# dist_model.py # +############################################################ + +import os +from collections import OrderedDict +from scipy.ndimage import zoom +from tqdm import tqdm + + +class DistModel(BaseModel): + def name(self): + return self.model_name + + def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, + model_path=None, + use_gpu=True, printNet=False, spatial=False, + is_train=False, lr=.0001, beta1=0.5, version='0.1'): + ''' + INPUTS + model - ['net-lin'] for linearly calibrated network + ['net'] for off-the-shelf network + ['L2'] for L2 distance in Lab colorspace + ['SSIM'] for ssim in RGB colorspace + net - ['squeeze','alex','vgg'] + model_path - if None, will look in weights/[NET_NAME].pth + colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM + use_gpu - bool - whether or not to use a GPU + printNet - bool - whether or not to print network architecture out + spatial - bool - whether to output an array containing varying distances across spatial dimensions + spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). + spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. + spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). + is_train - bool - [True] for training mode + lr - float - initial learning rate + beta1 - float - initial momentum term for adam + version - 0.1 for latest, 0.0 was original (with a bug) + ''' + BaseModel.initialize(self, use_gpu=use_gpu) + + self.model = model + self.net = net + self.is_train = is_train + self.spatial = spatial + self.model_name = '%s [%s]' % (model, net) + + if (self.model == 'net-lin'): # pretrained net + linear layer + self.net = PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, + use_dropout=True, spatial=spatial, version=version, lpips=True) + kw = dict(map_location='cpu') + if (model_path is None): + import inspect + model_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), '..', '..', '..', 'models', 'lpips_models', f'{net}.pth')) + + if (not is_train): + self.net.load_state_dict(torch.load(model_path, **kw), strict=False) + + elif (self.model == 'net'): # pretrained network + self.net = PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) + elif (self.model in ['L2', 'l2']): + self.net = L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing + self.model_name = 'L2' + elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']): + self.net = DSSIM(use_gpu=use_gpu, colorspace=colorspace) + self.model_name = 'SSIM' + else: + raise ValueError("Model [%s] not recognized." % self.model) + + self.trainable_parameters = list(self.net.parameters()) + + if self.is_train: # training mode + # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) + self.rankLoss = BCERankingLoss() + self.trainable_parameters += list(self.rankLoss.net.parameters()) + self.lr = lr + self.old_lr = lr + self.optimizer_net = torch.optim.Adam(self.trainable_parameters, lr=lr, betas=(beta1, 0.999)) + else: # test mode + self.net.eval() + + # if (use_gpu): + # self.net.to(gpu_ids[0]) + # self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) + # if (self.is_train): + # self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 + + if (printNet): + print('---------- Networks initialized -------------') + print_network(self.net) + print('-----------------------------------------------') + + def forward(self, in0, in1, retPerLayer=False): + ''' Function computes the distance between image patches in0 and in1 + INPUTS + in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] + OUTPUT + computed distances between in0 and in1 + ''' + + return self.net(in0, in1, retPerLayer=retPerLayer) + + # ***** TRAINING FUNCTIONS ***** + def optimize_parameters(self): + self.forward_train() + self.optimizer_net.zero_grad() + self.backward_train() + self.optimizer_net.step() + self.clamp_weights() + + def clamp_weights(self): + for module in self.net.modules(): + if (hasattr(module, 'weight') and module.kernel_size == (1, 1)): + module.weight.data = torch.clamp(module.weight.data, min=0) + + def set_input(self, data): + self.input_ref = data['ref'] + self.input_p0 = data['p0'] + self.input_p1 = data['p1'] + self.input_judge = data['judge'] + + # if (self.use_gpu): + # self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) + # self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) + # self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) + # self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) + + # self.var_ref = Variable(self.input_ref, requires_grad=True) + # self.var_p0 = Variable(self.input_p0, requires_grad=True) + # self.var_p1 = Variable(self.input_p1, requires_grad=True) + + def forward_train(self): # run forward pass + # print(self.net.module.scaling_layer.shift) + # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) + + assert False, "We shoud've not get here when using LPIPS as a metric" + + self.d0 = self(self.var_ref, self.var_p0) + self.d1 = self(self.var_ref, self.var_p1) + self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) + + self.var_judge = Variable(1. * self.input_judge).view(self.d0.size()) + + self.loss_total = self.rankLoss(self.d0, self.d1, self.var_judge * 2. - 1.) + + return self.loss_total + + def backward_train(self): + torch.mean(self.loss_total).backward() + + def compute_accuracy(self, d0, d1, judge): + ''' d0, d1 are Variables, judge is a Tensor ''' + d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() + judge_per = judge.cpu().numpy().flatten() + return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) + + def get_current_errors(self): + retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), + ('acc_r', self.acc_r)]) + + for key in retDict.keys(): + retDict[key] = np.mean(retDict[key]) + + return retDict + + def get_current_visuals(self): + zoom_factor = 256 / self.var_ref.data.size()[2] + + ref_img = tensor2im(self.var_ref.data) + p0_img = tensor2im(self.var_p0.data) + p1_img = tensor2im(self.var_p1.data) + + ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) + p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) + p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) + + return OrderedDict([('ref', ref_img_vis), + ('p0', p0_img_vis), + ('p1', p1_img_vis)]) + + def save(self, path, label): + if (self.use_gpu): + self.save_network(self.net.module, path, '', label) + else: + self.save_network(self.net, path, '', label) + self.save_network(self.rankLoss.net, path, 'rank', label) + + def update_learning_rate(self, nepoch_decay): + lrd = self.lr / nepoch_decay + lr = self.old_lr - lrd + + for param_group in self.optimizer_net.param_groups: + param_group['lr'] = lr + + print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr)) + self.old_lr = lr + + +def score_2afc_dataset(data_loader, func, name=''): + ''' Function computes Two Alternative Forced Choice (2AFC) score using + distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return numpy array of length N + OUTPUTS + [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators + [1] - dictionary with following elements + d0s,d1s - N arrays containing distances between reference patch to perturbed patches + gts - N array in [0,1], preferred patch selected by human evaluators + (closer to "0" for left patch p0, "1" for right patch p1, + "0.6" means 60pct people preferred right patch, 40pct preferred left) + scores - N array in [0,1], corresponding to what percentage function agreed with humans + CONSTS + N - number of test triplets in data_loader + ''' + + d0s = [] + d1s = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist() + d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist() + gts += data['judge'].cpu().numpy().flatten().tolist() + + d0s = np.array(d0s) + d1s = np.array(d1s) + gts = np.array(gts) + scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5 + + return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) + + +def score_jnd_dataset(data_loader, func, name=''): + ''' Function computes JND score using distance function 'func' in dataset 'data_loader' + INPUTS + data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside + func - callable distance function - calling d=func(in0,in1) should take 2 + pytorch tensors with shape Nx3xXxY, and return pytorch array of length N + OUTPUTS + [0] - JND score in [0,1], mAP score (area under precision-recall curve) + [1] - dictionary with following elements + ds - N array containing distances between two patches shown to human evaluator + sames - N array containing fraction of people who thought the two patches were identical + CONSTS + N - number of test triplets in data_loader + ''' + + ds = [] + gts = [] + + for data in tqdm(data_loader.load_data(), desc=name): + ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist() + gts += data['same'].cpu().numpy().flatten().tolist() + + sames = np.array(gts) + ds = np.array(ds) + + sorted_inds = np.argsort(ds) + ds_sorted = ds[sorted_inds] + sames_sorted = sames[sorted_inds] + + TPs = np.cumsum(sames_sorted) + FPs = np.cumsum(1 - sames_sorted) + FNs = np.sum(sames_sorted) - TPs + + precs = TPs / (TPs + FPs) + recs = TPs / (TPs + FNs) + score = voc_ap(recs, precs) + + return (score, dict(ds=ds, sames=sames)) + + +############################################################ +# networks_basic.py # +############################################################ + +import torch.nn as nn +from torch.autograd import Variable +import numpy as np + + +def spatial_average(in_tens, keepdim=True): + return in_tens.mean([2, 3], keepdim=keepdim) + + +def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W + in_H = in_tens.shape[2] + scale_factor = 1. * out_H / in_H + + return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) + + +# Learned perceptual metric +class PNetLin(nn.Module): + def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, + version='0.1', lpips=True): + super(PNetLin, self).__init__() + + self.pnet_type = pnet_type + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.lpips = lpips + self.version = version + self.scaling_layer = ScalingLayer() + + if (self.pnet_type in ['vgg', 'vgg16']): + net_type = vgg16 + self.chns = [64, 128, 256, 512, 512] + elif (self.pnet_type == 'alex'): + net_type = alexnet + self.chns = [64, 192, 384, 256, 256] + elif (self.pnet_type == 'squeeze'): + net_type = squeezenet + self.chns = [64, 128, 256, 384, 384, 512, 512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + if (lpips): + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + + def forward(self, in0, in1, retPerLayer=False): + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( + in0, in1) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + if (self.lpips): + if (self.spatial): + res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if (self.spatial): + res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] + + val = res[0] + for l in range(1, self.L): + val += res[l] + + if (retPerLayer): + return (val, res) + else: + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + ''' A single linear layer which does a 1x1 conv ''' + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class Dist2LogitLayer(nn.Module): + ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' + + def __init__(self, chn_mid=32, use_sigmoid=True): + super(Dist2LogitLayer, self).__init__() + + layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] + layers += [nn.LeakyReLU(0.2, True), ] + layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] + layers += [nn.LeakyReLU(0.2, True), ] + layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] + if (use_sigmoid): + layers += [nn.Sigmoid(), ] + self.model = nn.Sequential(*layers) + + def forward(self, d0, d1, eps=0.1): + return self.model(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) + + +class BCERankingLoss(nn.Module): + def __init__(self, chn_mid=32): + super(BCERankingLoss, self).__init__() + self.net = Dist2LogitLayer(chn_mid=chn_mid) + # self.parameters = list(self.net.parameters()) + self.loss = torch.nn.BCELoss() + + def forward(self, d0, d1, judge): + per = (judge + 1.) / 2. + self.logit = self.net(d0, d1) + return self.loss(self.logit, per) + + +# L2, DSSIM metrics +class FakeNet(nn.Module): + def __init__(self, use_gpu=True, colorspace='Lab'): + super(FakeNet, self).__init__() + self.use_gpu = use_gpu + self.colorspace = colorspace + + +class L2(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert (in0.size()[0] == 1) # currently only supports batchSize 1 + + if (self.colorspace == 'RGB'): + (N, C, X, Y) = in0.size() + value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), + dim=3).view(N) + return value + elif (self.colorspace == 'Lab'): + value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), + tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') + ret_var = Variable(torch.Tensor((value,))) + # if (self.use_gpu): + # ret_var = ret_var.cuda() + return ret_var + + +class DSSIM(FakeNet): + + def forward(self, in0, in1, retPerLayer=None): + assert (in0.size()[0] == 1) # currently only supports batchSize 1 + + if (self.colorspace == 'RGB'): + value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype('float') + elif (self.colorspace == 'Lab'): + value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), + tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') + ret_var = Variable(torch.Tensor((value,))) + # if (self.use_gpu): + # ret_var = ret_var.cuda() + return ret_var + + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print('Network', net) + print('Total number of parameters: %d' % num_params) + + +############################################################ +# pretrained_networks.py # +############################################################ + +from collections import namedtuple +import torch +from torchvision import models as tv + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if (num == 18): + self.net = tv.resnet18(pretrained=pretrained) + elif (num == 34): + self.net = tv.resnet34(pretrained=pretrained) + elif (num == 50): + self.net = tv.resnet50(pretrained=pretrained) + elif (num == 101): + self.net = tv.resnet101(pretrained=pretrained) + elif (num == 152): + self.net = tv.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out diff --git a/saicinpainting/evaluation/losses/ssim.py b/saicinpainting/evaluation/losses/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..ee43a0095408eca98e253dea194db788446f9c0a --- /dev/null +++ b/saicinpainting/evaluation/losses/ssim.py @@ -0,0 +1,74 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +class SSIM(torch.nn.Module): + """SSIM. Modified from: + https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + """ + + def __init__(self, window_size=11, size_average=True): + super().__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.register_buffer('window', self._create_window(window_size, self.channel)) + + def forward(self, img1, img2): + assert len(img1.shape) == 4 + + channel = img1.size()[1] + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = self._create_window(self.window_size, channel) + + # window = window.to(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) + + def _gaussian(self, window_size, sigma): + gauss = torch.Tensor([ + np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size) + ]) + return gauss / gauss.sum() + + def _create_window(self, window_size, channel): + _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + return _2D_window.expand(channel, 1, window_size, window_size).contiguous() + + def _ssim(self, img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel) + mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d( + img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq + sigma2_sq = F.conv2d( + img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq + sigma12 = F.conv2d( + img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ + ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + + return ssim_map.mean(1).mean(1).mean(1) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + return diff --git a/saicinpainting/evaluation/masks/README.md b/saicinpainting/evaluation/masks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..cf176bc10fae3b03f139727147c220f2a735c806 --- /dev/null +++ b/saicinpainting/evaluation/masks/README.md @@ -0,0 +1,27 @@ +# Current algorithm + +## Choice of mask objects + +For identification of the objects which are suitable for mask obtaining, panoptic segmentation model +from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances +belong either to "stuff" or "things" types. We consider that instances of objects should have category belong +to "things". Besides, we set upper bound on area which is taken by the object — we consider that too big +area indicates either of the instance being a background or a main object which should not be removed. + +## Choice of position for mask + +We consider that input image has size 2^n x 2^m. We downsample it using +[COUNTLESS](https://github.com/william-silversmith/countless) algorithm so the width is equal to +64 = 2^8 = 2^{downsample_levels}. + +### Augmentation + +There are several parameters for augmentation: +- Scaling factor. We limit scaling to the case when a mask after scaling with pivot point in its center fits inside the + image completely. +- + +### Shift + + +## Select diff --git a/saicinpainting/evaluation/masks/__init__.py b/saicinpainting/evaluation/masks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/evaluation/masks/countless/.gitignore b/saicinpainting/evaluation/masks/countless/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..872aa273a4e3ec99d362cefa1c67550f21f3c366 --- /dev/null +++ b/saicinpainting/evaluation/masks/countless/.gitignore @@ -0,0 +1 @@ +results \ No newline at end of file diff --git a/saicinpainting/evaluation/masks/countless/README.md b/saicinpainting/evaluation/masks/countless/README.md new file mode 100644 index 0000000000000000000000000000000000000000..67335464d794776140fd0308f408608f2231309b --- /dev/null +++ b/saicinpainting/evaluation/masks/countless/README.md @@ -0,0 +1,25 @@ +[![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless) + +Python COUNTLESS Downsampling +============================= + +To install: + +`pip install -r requirements.txt` + +To test: + +`python test.py` + +To benchmark countless2d: + +`python python/countless2d.py python/images/gray_segmentation.png` + +To benchmark countless3d: + +`python python/countless3d.py` + +Adjust N and the list of algorithms inside each script to modify the run parameters. + + +Python3 is slightly faster than Python2. \ No newline at end of file diff --git a/saicinpainting/evaluation/masks/countless/__init__.py b/saicinpainting/evaluation/masks/countless/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/evaluation/masks/countless/countless2d.py b/saicinpainting/evaluation/masks/countless/countless2d.py new file mode 100644 index 0000000000000000000000000000000000000000..dc27b73affa20ab1a8a199542469a10aaf1f555a --- /dev/null +++ b/saicinpainting/evaluation/masks/countless/countless2d.py @@ -0,0 +1,529 @@ +from __future__ import print_function, division + +""" +COUNTLESS performance test in Python. + +python countless2d.py ./images/NAMEOFIMAGE +""" + +import six +from six.moves import range +from collections import defaultdict +from functools import reduce +import operator +import io +import os +from PIL import Image +import math +import numpy as np +import random +import sys +import time +from tqdm import tqdm +from scipy import ndimage + +def simplest_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab = a * (a == b) # PICK(A,B) + ac = a * (a == c) # PICK(A,C) + bc = b * (b == c) # PICK(B,C) + + a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed + + return a + (a == 0) * d # AB || AC || BC || D + +def quick_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + bc = b * (b == c) # PICK(B,C) + + a = ab_ac | bc # (PICK(A,B) || PICK(A,C)) or PICK(B,C) + return a + (a == 0) * d # AB || AC || BC || D + +def quickest_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + ab_ac |= b * (b == c) # PICK(B,C) + return ab_ac + (ab_ac == 0) * d # AB || AC || BC || D + +def quick_countless_xor(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab = a ^ (a ^ b) # a or b + ab += (ab != a) * ((ab ^ (ab ^ c)) - b) # b or c + ab += (ab == c) * ((ab ^ (ab ^ d)) - c) # c or d + return ab + +def stippled_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm + that treats zero as "background" and inflates lone + pixels. + + data is a 2D numpy array with even dimensions. + """ + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + ab_ac |= b * (b == c) # PICK(B,C) + + nonzero = a + (a == 0) * (b + (b == 0) * c) + return ab_ac + (ab_ac == 0) * (d + (d == 0) * nonzero) # AB || AC || BC || D + +def zero_corrected_countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + # allows us to prevent losing 1/2 a bit of information + # at the top end by using a bigger type. Without this 255 is handled incorrectly. + data, upgraded = upgrade_type(data) + + # offset from zero, raw countless doesn't handle 0 correctly + # we'll remove the extra 1 at the end. + data += 1 + + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab = a * (a == b) # PICK(A,B) + ac = a * (a == c) # PICK(A,C) + bc = b * (b == c) # PICK(B,C) + + a = ab | ac | bc # Bitwise OR, safe b/c non-matches are zeroed + + result = a + (a == 0) * d - 1 # a or d - 1 + + if upgraded: + return downgrade_type(result) + + # only need to reset data if we weren't upgraded + # b/c no copy was made in that case + data -= 1 + + return result + +def countless_extreme(data): + nonzeros = np.count_nonzero(data) + # print("nonzeros", nonzeros) + + N = reduce(operator.mul, data.shape) + + if nonzeros == N: + print("quick") + return quick_countless(data) + elif np.count_nonzero(data + 1) == N: + print("quick") + # print("upper", nonzeros) + return quick_countless(data) + else: + return countless(data) + + +def countless(data): + """ + Vectorized implementation of downsampling a 2D + image by 2 on each side using the COUNTLESS algorithm. + + data is a 2D numpy array with even dimensions. + """ + # allows us to prevent losing 1/2 a bit of information + # at the top end by using a bigger type. Without this 255 is handled incorrectly. + data, upgraded = upgrade_type(data) + + # offset from zero, raw countless doesn't handle 0 correctly + # we'll remove the extra 1 at the end. + data += 1 + + sections = [] + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + a, b, c, d = sections + + ab_ac = a * ((a == b) | (a == c)) # PICK(A,B) || PICK(A,C) w/ optimization + ab_ac |= b * (b == c) # PICK(B,C) + result = ab_ac + (ab_ac == 0) * d - 1 # (matches or d) - 1 + + if upgraded: + return downgrade_type(result) + + # only need to reset data if we weren't upgraded + # b/c no copy was made in that case + data -= 1 + + return result + +def upgrade_type(arr): + dtype = arr.dtype + + if dtype == np.uint8: + return arr.astype(np.uint16), True + elif dtype == np.uint16: + return arr.astype(np.uint32), True + elif dtype == np.uint32: + return arr.astype(np.uint64), True + + return arr, False + +def downgrade_type(arr): + dtype = arr.dtype + + if dtype == np.uint64: + return arr.astype(np.uint32) + elif dtype == np.uint32: + return arr.astype(np.uint16) + elif dtype == np.uint16: + return arr.astype(np.uint8) + + return arr + +def odd_to_even(image): + """ + To facilitate 2x2 downsampling segmentation, change an odd sized image into an even sized one. + Works by mirroring the starting 1 pixel edge of the image on odd shaped sides. + + e.g. turn a 3x3x5 image into a 4x4x5 (the x and y are what are getting downsampled) + + For example: [ 3, 2, 4 ] => [ 3, 3, 2, 4 ] which is now easy to downsample. + + """ + shape = np.array(image.shape) + + offset = (shape % 2)[:2] # x,y offset + + # detect if we're dealing with an even + # image. if so it's fine, just return. + if not np.any(offset): + return image + + oddshape = image.shape[:2] + offset + oddshape = np.append(oddshape, shape[2:]) + oddshape = oddshape.astype(int) + + newimg = np.empty(shape=oddshape, dtype=image.dtype) + + ox,oy = offset + sx,sy = oddshape + + newimg[0,0] = image[0,0] # corner + newimg[ox:sx,0] = image[:,0] # x axis line + newimg[0,oy:sy] = image[0,:] # y axis line + + return newimg + +def counting(array): + factor = (2, 2, 1) + shape = array.shape + + while len(shape) < 4: + array = np.expand_dims(array, axis=-1) + shape = array.shape + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor)) + output = np.zeros(output_shape, dtype=array.dtype) + + for chan in range(0, shape[3]): + for z in range(0, shape[2]): + for x in range(0, shape[0], 2): + for y in range(0, shape[1], 2): + block = array[ x:x+2, y:y+2, z, chan ] # 2x2 block + + hashtable = defaultdict(int) + for subx, suby in np.ndindex(block.shape[0], block.shape[1]): + hashtable[block[subx, suby]] += 1 + + best = (0, 0) + for segid, val in six.iteritems(hashtable): + if best[1] < val: + best = (segid, val) + + output[ x // 2, y // 2, chan ] = best[0] + + return output + +def ndzoom(array): + if len(array.shape) == 3: + ratio = ( 1 / 2.0, 1 / 2.0, 1.0 ) + else: + ratio = ( 1 / 2.0, 1 / 2.0) + return ndimage.interpolation.zoom(array, ratio, order=1) + +def countless_if(array): + factor = (2, 2, 1) + shape = array.shape + + if len(shape) < 3: + array = array[ :,:, np.newaxis ] + shape = array.shape + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(shape, factor)) + output = np.zeros(output_shape, dtype=array.dtype) + + for chan in range(0, shape[2]): + for x in range(0, shape[0], 2): + for y in range(0, shape[1], 2): + block = array[ x:x+2, y:y+2, chan ] # 2x2 block + + if block[0,0] == block[1,0]: + pick = block[0,0] + elif block[0,0] == block[0,1]: + pick = block[0,0] + elif block[1,0] == block[0,1]: + pick = block[1,0] + else: + pick = block[1,1] + + output[ x // 2, y // 2, chan ] = pick + + return np.squeeze(output) + +def downsample_with_averaging(array): + """ + Downsample x by factor using averaging. + + @return: The downsampled array, of the same type as x. + """ + + if len(array.shape) == 3: + factor = (2,2,1) + else: + factor = (2,2) + + if np.array_equal(factor[:3], np.array([1,1,1])): + return array + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor)) + temp = np.zeros(output_shape, float) + counts = np.zeros(output_shape, np.int) + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + indexing_expr = tuple(np.s_[:s] for s in part.shape) + temp[indexing_expr] += part + counts[indexing_expr] += 1 + return np.cast[array.dtype](temp / counts) + +def downsample_with_max_pooling(array): + + factor = (2,2) + + if np.all(np.array(factor, int) == 1): + return array + + sections = [] + + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + output = sections[0].copy() + + for section in sections[1:]: + np.maximum(output, section, output) + + return output + +def striding(array): + """Downsample x by factor using striding. + + @return: The downsampled array, of the same type as x. + """ + factor = (2,2) + if np.all(np.array(factor, int) == 1): + return array + return array[tuple(np.s_[::f] for f in factor)] + +def benchmark(): + filename = sys.argv[1] + img = Image.open(filename) + data = np.array(img.getdata(), dtype=np.uint8) + + if len(data.shape) == 1: + n_channels = 1 + reshape = (img.height, img.width) + else: + n_channels = min(data.shape[1], 3) + data = data[:, :n_channels] + reshape = (img.height, img.width, n_channels) + + data = data.reshape(reshape).astype(np.uint8) + + methods = [ + simplest_countless, + quick_countless, + quick_countless_xor, + quickest_countless, + stippled_countless, + zero_corrected_countless, + countless, + downsample_with_averaging, + downsample_with_max_pooling, + ndzoom, + striding, + # countless_if, + # counting, + ] + + formats = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' + } + + if not os.path.exists('./results'): + os.mkdir('./results') + + N = 500 + img_size = float(img.width * img.height) / 1024.0 / 1024.0 + print("N = %d, %dx%d (%.2f MPx) %d chan, %s" % (N, img.width, img.height, img_size, n_channels, filename)) + print("Algorithm\tMPx/sec\tMB/sec\tSec") + for fn in methods: + print(fn.__name__, end='') + sys.stdout.flush() + + start = time.time() + # tqdm is here to show you what's going on the first time you run it. + # Feel free to remove it to get slightly more accurate timing results. + for _ in tqdm(range(N), desc=fn.__name__, disable=True): + result = fn(data) + end = time.time() + print("\r", end='') + + total_time = (end - start) + mpx = N * img_size / total_time + mbytes = N * img_size * n_channels / total_time + # Output in tab separated format to enable copy-paste into excel/numbers + print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time)) + outimg = Image.fromarray(np.squeeze(result), formats[n_channels]) + outimg.save('./results/{}.png'.format(fn.__name__, "PNG")) + +if __name__ == '__main__': + benchmark() + + +# Example results: +# N = 5, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png +# Function MPx/sec MB/sec Sec +# simplest_countless 752.855 752.855 0.01 +# quick_countless 920.328 920.328 0.01 +# zero_corrected_countless 534.143 534.143 0.01 +# countless 644.247 644.247 0.01 +# downsample_with_averaging 372.575 372.575 0.01 +# downsample_with_max_pooling 974.060 974.060 0.01 +# ndzoom 137.517 137.517 0.04 +# striding 38550.588 38550.588 0.00 +# countless_if 4.377 4.377 1.14 +# counting 0.117 0.117 42.85 + +# Run without non-numpy implementations: +# N = 2000, 1024x1024 (1.00 MPx) 1 chan, images/gray_segmentation.png +# Algorithm MPx/sec MB/sec Sec +# simplest_countless 800.522 800.522 2.50 +# quick_countless 945.420 945.420 2.12 +# quickest_countless 947.256 947.256 2.11 +# stippled_countless 544.049 544.049 3.68 +# zero_corrected_countless 575.310 575.310 3.48 +# countless 646.684 646.684 3.09 +# downsample_with_averaging 385.132 385.132 5.19 +# downsample_with_max_poolin 988.361 988.361 2.02 +# ndzoom 163.104 163.104 12.26 +# striding 81589.340 81589.340 0.02 + + + + diff --git a/saicinpainting/evaluation/masks/countless/countless3d.py b/saicinpainting/evaluation/masks/countless/countless3d.py new file mode 100644 index 0000000000000000000000000000000000000000..810a71e4b1fa344dd2d731186516dbfa96c9cd03 --- /dev/null +++ b/saicinpainting/evaluation/masks/countless/countless3d.py @@ -0,0 +1,356 @@ +from six.moves import range +from PIL import Image +import numpy as np +import io +import time +import math +import random +import sys +from collections import defaultdict +from copy import deepcopy +from itertools import combinations +from functools import reduce +from tqdm import tqdm + +from memory_profiler import profile + +def countless5(a,b,c,d,e): + """First stage of generalizing from countless2d. + + You have five slots: A, B, C, D, E + + You can decide if something is the winner by first checking for + matches of three, then matches of two, then picking just one if + the other two tries fail. In countless2d, you just check for matches + of two and then pick one of them otherwise. + + Unfortunately, you need to check ABC, ABD, ABE, BCD, BDE, & CDE. + Then you need to check AB, AC, AD, BC, BD + We skip checking E because if none of these match, we pick E. We can + skip checking AE, BE, CE, DE since if any of those match, E is our boy + so it's redundant. + + So countless grows cominatorially in complexity. + """ + sections = [ a,b,c,d,e ] + + p2 = lambda q,r: q * (q == r) # q if p == q else 0 + p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) # q if q == r == s else 0 + + lor = lambda x,y: x + (x == 0) * y + + results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) + results3 = reduce(lor, results3) + + results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) + results2 = reduce(lor, results2) + + return reduce(lor, (results3, results2, e)) + +def countless8(a,b,c,d,e,f,g,h): + """Extend countless5 to countless8. Same deal, except we also + need to check for matches of length 4.""" + sections = [ a, b, c, d, e, f, g, h ] + + p2 = lambda q,r: q * (q == r) + p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) + p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) ) + + lor = lambda x,y: x + (x == 0) * y + + results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) ) + results4 = reduce(lor, results4) + + results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) + results3 = reduce(lor, results3) + + # We can always use our shortcut of omitting the last element + # for N choose 2 + results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) + results2 = reduce(lor, results2) + + return reduce(lor, [ results4, results3, results2, h ]) + +def dynamic_countless3d(data): + """countless8 + dynamic programming. ~2x faster""" + sections = [] + + # shift zeros up one so they don't interfere with bitwise operators + # we'll shift down at the end + data += 1 + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + pick = lambda a,b: a * (a == b) + lor = lambda x,y: x + (x == 0) * y + + subproblems2 = {} + + results2 = None + for x,y in combinations(range(7), 2): + res = pick(sections[x], sections[y]) + subproblems2[(x,y)] = res + if results2 is not None: + results2 += (results2 == 0) * res + else: + results2 = res + + subproblems3 = {} + + results3 = None + for x,y,z in combinations(range(8), 3): + res = pick(subproblems2[(x,y)], sections[z]) + + if z != 7: + subproblems3[(x,y,z)] = res + + if results3 is not None: + results3 += (results3 == 0) * res + else: + results3 = res + + results3 = reduce(lor, (results3, results2, sections[-1])) + + # free memory + results2 = None + subproblems2 = None + res = None + + results4 = ( pick(subproblems3[(x,y,z)], sections[w]) for x,y,z,w in combinations(range(8), 4) ) + results4 = reduce(lor, results4) + subproblems3 = None # free memory + + final_result = lor(results4, results3) - 1 + data -= 1 + return final_result + +def countless3d(data): + """Now write countless8 in such a way that it could be used + to process an image.""" + sections = [] + + # shift zeros up one so they don't interfere with bitwise operators + # we'll shift down at the end + data += 1 + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + factor = (2,2,2) + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + p2 = lambda q,r: q * (q == r) + p3 = lambda q,r,s: q * ( (q == r) & (r == s) ) + p4 = lambda p,q,r,s: p * ( (p == q) & (q == r) & (r == s) ) + + lor = lambda x,y: x + (x == 0) * y + + results4 = ( p4(x,y,z,w) for x,y,z,w in combinations(sections, 4) ) + results4 = reduce(lor, results4) + + results3 = ( p3(x,y,z) for x,y,z in combinations(sections, 3) ) + results3 = reduce(lor, results3) + + results2 = ( p2(x,y) for x,y in combinations(sections[:-1], 2) ) + results2 = reduce(lor, results2) + + final_result = reduce(lor, (results4, results3, results2, sections[-1])) - 1 + data -= 1 + return final_result + +def countless_generalized(data, factor): + assert len(data.shape) == len(factor) + + sections = [] + + mode_of = reduce(lambda x,y: x * y, factor) + majority = int(math.ceil(float(mode_of) / 2)) + + data += 1 + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + def pick(elements): + eq = ( elements[i] == elements[i+1] for i in range(len(elements) - 1) ) + anded = reduce(lambda p,q: p & q, eq) + return elements[0] * anded + + def logical_or(x,y): + return x + (x == 0) * y + + result = ( pick(combo) for combo in combinations(sections, majority) ) + result = reduce(logical_or, result) + for i in range(majority - 1, 3-1, -1): # 3-1 b/c of exclusive bounds + partial_result = ( pick(combo) for combo in combinations(sections, i) ) + partial_result = reduce(logical_or, partial_result) + result = logical_or(result, partial_result) + + partial_result = ( pick(combo) for combo in combinations(sections[:-1], 2) ) + partial_result = reduce(logical_or, partial_result) + result = logical_or(result, partial_result) + + result = logical_or(result, sections[-1]) - 1 + data -= 1 + return result + +def dynamic_countless_generalized(data, factor): + assert len(data.shape) == len(factor) + + sections = [] + + mode_of = reduce(lambda x,y: x * y, factor) + majority = int(math.ceil(float(mode_of) / 2)) + + data += 1 # offset from zero + + # This loop splits the 2D array apart into four arrays that are + # all the result of striding by 2 and offset by (0,0), (0,1), (1,0), + # and (1,1) representing the A, B, C, and D positions from Figure 1. + for offset in np.ndindex(factor): + part = data[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + pick = lambda a,b: a * (a == b) + lor = lambda x,y: x + (x == 0) * y # logical or + + subproblems = [ {}, {} ] + results2 = None + for x,y in combinations(range(len(sections) - 1), 2): + res = pick(sections[x], sections[y]) + subproblems[0][(x,y)] = res + if results2 is not None: + results2 = lor(results2, res) + else: + results2 = res + + results = [ results2 ] + for r in range(3, majority+1): + r_results = None + for combo in combinations(range(len(sections)), r): + res = pick(subproblems[0][combo[:-1]], sections[combo[-1]]) + + if combo[-1] != len(sections) - 1: + subproblems[1][combo] = res + + if r_results is not None: + r_results = lor(r_results, res) + else: + r_results = res + results.append(r_results) + subproblems[0] = subproblems[1] + subproblems[1] = {} + + results.reverse() + final_result = lor(reduce(lor, results), sections[-1]) - 1 + data -= 1 + return final_result + +def downsample_with_averaging(array): + """ + Downsample x by factor using averaging. + + @return: The downsampled array, of the same type as x. + """ + factor = (2,2,2) + + if np.array_equal(factor[:3], np.array([1,1,1])): + return array + + output_shape = tuple(int(math.ceil(s / f)) for s, f in zip(array.shape, factor)) + temp = np.zeros(output_shape, float) + counts = np.zeros(output_shape, np.int) + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + indexing_expr = tuple(np.s_[:s] for s in part.shape) + temp[indexing_expr] += part + counts[indexing_expr] += 1 + return np.cast[array.dtype](temp / counts) + +def downsample_with_max_pooling(array): + + factor = (2,2,2) + + sections = [] + + for offset in np.ndindex(factor): + part = array[tuple(np.s_[o::f] for o, f in zip(offset, factor))] + sections.append(part) + + output = sections[0].copy() + + for section in sections[1:]: + np.maximum(output, section, output) + + return output + +def striding(array): + """Downsample x by factor using striding. + + @return: The downsampled array, of the same type as x. + """ + factor = (2,2,2) + if np.all(np.array(factor, int) == 1): + return array + return array[tuple(np.s_[::f] for f in factor)] + +def benchmark(): + def countless3d_generalized(img): + return countless_generalized(img, (2,8,1)) + def countless3d_dynamic_generalized(img): + return dynamic_countless_generalized(img, (8,8,1)) + + methods = [ + # countless3d, + # dynamic_countless3d, + countless3d_generalized, + # countless3d_dynamic_generalized, + # striding, + # downsample_with_averaging, + # downsample_with_max_pooling + ] + + data = np.zeros(shape=(16**2, 16**2, 16**2), dtype=np.uint8) + 1 + + N = 5 + + print('Algorithm\tMPx\tMB/sec\tSec\tN=%d' % N) + + for fn in methods: + start = time.time() + for _ in range(N): + result = fn(data) + end = time.time() + + total_time = (end - start) + mpx = N * float(data.shape[0] * data.shape[1] * data.shape[2]) / total_time / 1024.0 / 1024.0 + mbytes = mpx * np.dtype(data.dtype).itemsize + # Output in tab separated format to enable copy-paste into excel/numbers + print("%s\t%.3f\t%.3f\t%.2f" % (fn.__name__, mpx, mbytes, total_time)) + +if __name__ == '__main__': + benchmark() + +# Algorithm MPx MB/sec Sec N=5 +# countless3d 10.564 10.564 60.58 +# dynamic_countless3d 22.717 22.717 28.17 +# countless3d_generalized 9.702 9.702 65.96 +# countless3d_dynamic_generalized 22.720 22.720 28.17 +# striding 253360.506 253360.506 0.00 +# downsample_with_averaging 224.098 224.098 2.86 +# downsample_with_max_pooling 690.474 690.474 0.93 + + + diff --git a/saicinpainting/evaluation/masks/countless/requirements.txt b/saicinpainting/evaluation/masks/countless/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..cbf8c87bf9b4c9fe54cb39d722253c0ab59e63ad --- /dev/null +++ b/saicinpainting/evaluation/masks/countless/requirements.txt @@ -0,0 +1,7 @@ +Pillow>=6.2.0 +numpy>=1.16 +scipy +tqdm +memory_profiler +six +pytest \ No newline at end of file diff --git a/saicinpainting/evaluation/masks/mask.py b/saicinpainting/evaluation/masks/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..3e34d0675a781fba983cb542f18390255aaf2609 --- /dev/null +++ b/saicinpainting/evaluation/masks/mask.py @@ -0,0 +1,429 @@ +import enum +from copy import deepcopy + +import numpy as np +from skimage import img_as_ubyte +from skimage.transform import rescale, resize +try: + from detectron2 import model_zoo + from detectron2.config import get_cfg + from detectron2.engine import DefaultPredictor + DETECTRON_INSTALLED = True +except: + print("Detectron v2 is not installed") + DETECTRON_INSTALLED = False + +from .countless.countless2d import zero_corrected_countless + + +class ObjectMask(): + def __init__(self, mask): + self.height, self.width = mask.shape + (self.up, self.down), (self.left, self.right) = self._get_limits(mask) + self.mask = mask[self.up:self.down, self.left:self.right].copy() + + @staticmethod + def _get_limits(mask): + def indicator_limits(indicator): + lower = indicator.argmax() + upper = len(indicator) - indicator[::-1].argmax() + return lower, upper + + vertical_indicator = mask.any(axis=1) + vertical_limits = indicator_limits(vertical_indicator) + + horizontal_indicator = mask.any(axis=0) + horizontal_limits = indicator_limits(horizontal_indicator) + + return vertical_limits, horizontal_limits + + def _clean(self): + self.up, self.down, self.left, self.right = 0, 0, 0, 0 + self.mask = np.empty((0, 0)) + + def horizontal_flip(self, inplace=False): + if not inplace: + flipped = deepcopy(self) + return flipped.horizontal_flip(inplace=True) + + self.mask = self.mask[:, ::-1] + return self + + def vertical_flip(self, inplace=False): + if not inplace: + flipped = deepcopy(self) + return flipped.vertical_flip(inplace=True) + + self.mask = self.mask[::-1, :] + return self + + def image_center(self): + y_center = self.up + (self.down - self.up) / 2 + x_center = self.left + (self.right - self.left) / 2 + return y_center, x_center + + def rescale(self, scaling_factor, inplace=False): + if not inplace: + scaled = deepcopy(self) + return scaled.rescale(scaling_factor, inplace=True) + + scaled_mask = rescale(self.mask.astype(float), scaling_factor, order=0) > 0.5 + (up, down), (left, right) = self._get_limits(scaled_mask) + self.mask = scaled_mask[up:down, left:right] + + y_center, x_center = self.image_center() + mask_height, mask_width = self.mask.shape + self.up = int(round(y_center - mask_height / 2)) + self.down = self.up + mask_height + self.left = int(round(x_center - mask_width / 2)) + self.right = self.left + mask_width + return self + + def crop_to_canvas(self, vertical=True, horizontal=True, inplace=False): + if not inplace: + cropped = deepcopy(self) + cropped.crop_to_canvas(vertical=vertical, horizontal=horizontal, inplace=True) + return cropped + + if vertical: + if self.up >= self.height or self.down <= 0: + self._clean() + else: + cut_up, cut_down = max(-self.up, 0), max(self.down - self.height, 0) + if cut_up != 0: + self.mask = self.mask[cut_up:] + self.up = 0 + if cut_down != 0: + self.mask = self.mask[:-cut_down] + self.down = self.height + + if horizontal: + if self.left >= self.width or self.right <= 0: + self._clean() + else: + cut_left, cut_right = max(-self.left, 0), max(self.right - self.width, 0) + if cut_left != 0: + self.mask = self.mask[:, cut_left:] + self.left = 0 + if cut_right != 0: + self.mask = self.mask[:, :-cut_right] + self.right = self.width + + return self + + def restore_full_mask(self, allow_crop=False): + cropped = self.crop_to_canvas(inplace=allow_crop) + mask = np.zeros((cropped.height, cropped.width), dtype=bool) + mask[cropped.up:cropped.down, cropped.left:cropped.right] = cropped.mask + return mask + + def shift(self, vertical=0, horizontal=0, inplace=False): + if not inplace: + shifted = deepcopy(self) + return shifted.shift(vertical=vertical, horizontal=horizontal, inplace=True) + + self.up += vertical + self.down += vertical + self.left += horizontal + self.right += horizontal + return self + + def area(self): + return self.mask.sum() + + +class RigidnessMode(enum.Enum): + soft = 0 + rigid = 1 + + +class SegmentationMask: + def __init__(self, confidence_threshold=0.5, rigidness_mode=RigidnessMode.rigid, + max_object_area=0.3, min_mask_area=0.02, downsample_levels=6, num_variants_per_mask=4, + max_mask_intersection=0.5, max_foreground_coverage=0.5, max_foreground_intersection=0.5, + max_hidden_area=0.2, max_scale_change=0.25, horizontal_flip=True, + max_vertical_shift=0.1, position_shuffle=True): + """ + :param confidence_threshold: float; threshold for confidence of the panoptic segmentator to allow for + the instance. + :param rigidness_mode: RigidnessMode object + when soft, checks intersection only with the object from which the mask_object was produced + when rigid, checks intersection with any foreground class object + :param max_object_area: float; allowed upper bound for to be considered as mask_object. + :param min_mask_area: float; lower bound for mask to be considered valid + :param downsample_levels: int; defines width of the resized segmentation to obtain shifted masks; + :param num_variants_per_mask: int; maximal number of the masks for the same object; + :param max_mask_intersection: float; maximum allowed area fraction of intersection for 2 masks + produced by horizontal shift of the same mask_object; higher value -> more diversity + :param max_foreground_coverage: float; maximum allowed area fraction of intersection for foreground object to be + covered by mask; lower value -> less the objects are covered + :param max_foreground_intersection: float; maximum allowed area of intersection for the mask with foreground + object; lower value -> mask is more on the background than on the objects + :param max_hidden_area: upper bound on part of the object hidden by shifting object outside the screen area; + :param max_scale_change: allowed scale change for the mask_object; + :param horizontal_flip: if horizontal flips are allowed; + :param max_vertical_shift: amount of vertical movement allowed; + :param position_shuffle: shuffle + """ + + assert DETECTRON_INSTALLED, 'Cannot use SegmentationMask without detectron2' + self.cfg = get_cfg() + self.cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")) + self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml") + self.cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold + self.predictor = DefaultPredictor(self.cfg) + + self.rigidness_mode = RigidnessMode(rigidness_mode) + self.max_object_area = max_object_area + self.min_mask_area = min_mask_area + self.downsample_levels = downsample_levels + self.num_variants_per_mask = num_variants_per_mask + self.max_mask_intersection = max_mask_intersection + self.max_foreground_coverage = max_foreground_coverage + self.max_foreground_intersection = max_foreground_intersection + self.max_hidden_area = max_hidden_area + self.position_shuffle = position_shuffle + + self.max_scale_change = max_scale_change + self.horizontal_flip = horizontal_flip + self.max_vertical_shift = max_vertical_shift + + def get_segmentation(self, img): + im = img_as_ubyte(img) + panoptic_seg, segment_info = self.predictor(im)["panoptic_seg"] + return panoptic_seg, segment_info + + @staticmethod + def _is_power_of_two(n): + return (n != 0) and (n & (n-1) == 0) + + def identify_candidates(self, panoptic_seg, segments_info): + potential_mask_ids = [] + for segment in segments_info: + if not segment["isthing"]: + continue + mask = (panoptic_seg == segment["id"]).int().detach().cpu().numpy() + area = mask.sum().item() / np.prod(panoptic_seg.shape) + if area >= self.max_object_area: + continue + potential_mask_ids.append(segment["id"]) + return potential_mask_ids + + def downsample_mask(self, mask): + height, width = mask.shape + if not (self._is_power_of_two(height) and self._is_power_of_two(width)): + raise ValueError("Image sides are not power of 2.") + + num_iterations = width.bit_length() - 1 - self.downsample_levels + if num_iterations < 0: + raise ValueError(f"Width is lower than 2^{self.downsample_levels}.") + + if height.bit_length() - 1 < num_iterations: + raise ValueError("Height is too low to perform downsampling") + + downsampled = mask + for _ in range(num_iterations): + downsampled = zero_corrected_countless(downsampled) + + return downsampled + + def _augmentation_params(self): + scaling_factor = np.random.uniform(1 - self.max_scale_change, 1 + self.max_scale_change) + if self.horizontal_flip: + horizontal_flip = bool(np.random.choice(2)) + else: + horizontal_flip = False + vertical_shift = np.random.uniform(-self.max_vertical_shift, self.max_vertical_shift) + + return { + "scaling_factor": scaling_factor, + "horizontal_flip": horizontal_flip, + "vertical_shift": vertical_shift + } + + def _get_intersection(self, mask_array, mask_object): + intersection = mask_array[ + mask_object.up:mask_object.down, mask_object.left:mask_object.right + ] & mask_object.mask + return intersection + + def _check_masks_intersection(self, aug_mask, total_mask_area, prev_masks): + for existing_mask in prev_masks: + intersection_area = self._get_intersection(existing_mask, aug_mask).sum() + intersection_existing = intersection_area / existing_mask.sum() + intersection_current = 1 - (aug_mask.area() - intersection_area) / total_mask_area + if (intersection_existing > self.max_mask_intersection) or \ + (intersection_current > self.max_mask_intersection): + return False + return True + + def _check_foreground_intersection(self, aug_mask, foreground): + for existing_mask in foreground: + intersection_area = self._get_intersection(existing_mask, aug_mask).sum() + intersection_existing = intersection_area / existing_mask.sum() + if intersection_existing > self.max_foreground_coverage: + return False + intersection_mask = intersection_area / aug_mask.area() + if intersection_mask > self.max_foreground_intersection: + return False + return True + + def _move_mask(self, mask, foreground): + # Obtaining properties of the original mask_object: + orig_mask = ObjectMask(mask) + + chosen_masks = [] + chosen_parameters = [] + # to fix the case when resizing gives mask_object consisting only of False + scaling_factor_lower_bound = 0. + + for var_idx in range(self.num_variants_per_mask): + # Obtaining augmentation parameters and applying them to the downscaled mask_object + augmentation_params = self._augmentation_params() + augmentation_params["scaling_factor"] = min([ + augmentation_params["scaling_factor"], + 2 * min(orig_mask.up, orig_mask.height - orig_mask.down) / orig_mask.height + 1., + 2 * min(orig_mask.left, orig_mask.width - orig_mask.right) / orig_mask.width + 1. + ]) + augmentation_params["scaling_factor"] = max([ + augmentation_params["scaling_factor"], scaling_factor_lower_bound + ]) + + aug_mask = deepcopy(orig_mask) + aug_mask.rescale(augmentation_params["scaling_factor"], inplace=True) + if augmentation_params["horizontal_flip"]: + aug_mask.horizontal_flip(inplace=True) + total_aug_area = aug_mask.area() + if total_aug_area == 0: + scaling_factor_lower_bound = 1. + continue + + # Fix if the element vertical shift is too strong and shown area is too small: + vertical_area = aug_mask.mask.sum(axis=1) / total_aug_area # share of area taken by rows + # number of rows which are allowed to be hidden from upper and lower parts of image respectively + max_hidden_up = np.searchsorted(vertical_area.cumsum(), self.max_hidden_area) + max_hidden_down = np.searchsorted(vertical_area[::-1].cumsum(), self.max_hidden_area) + # correcting vertical shift, so not too much area will be hidden + augmentation_params["vertical_shift"] = np.clip( + augmentation_params["vertical_shift"], + -(aug_mask.up + max_hidden_up) / aug_mask.height, + (aug_mask.height - aug_mask.down + max_hidden_down) / aug_mask.height + ) + # Applying vertical shift: + vertical_shift = int(round(aug_mask.height * augmentation_params["vertical_shift"])) + aug_mask.shift(vertical=vertical_shift, inplace=True) + aug_mask.crop_to_canvas(vertical=True, horizontal=False, inplace=True) + + # Choosing horizontal shift: + max_hidden_area = self.max_hidden_area - (1 - aug_mask.area() / total_aug_area) + horizontal_area = aug_mask.mask.sum(axis=0) / total_aug_area + max_hidden_left = np.searchsorted(horizontal_area.cumsum(), max_hidden_area) + max_hidden_right = np.searchsorted(horizontal_area[::-1].cumsum(), max_hidden_area) + allowed_shifts = np.arange(-max_hidden_left, aug_mask.width - + (aug_mask.right - aug_mask.left) + max_hidden_right + 1) + allowed_shifts = - (aug_mask.left - allowed_shifts) + + if self.position_shuffle: + np.random.shuffle(allowed_shifts) + + mask_is_found = False + for horizontal_shift in allowed_shifts: + aug_mask_left = deepcopy(aug_mask) + aug_mask_left.shift(horizontal=horizontal_shift, inplace=True) + aug_mask_left.crop_to_canvas(inplace=True) + + prev_masks = [mask] + chosen_masks + is_mask_suitable = self._check_masks_intersection(aug_mask_left, total_aug_area, prev_masks) & \ + self._check_foreground_intersection(aug_mask_left, foreground) + if is_mask_suitable: + aug_draw = aug_mask_left.restore_full_mask() + chosen_masks.append(aug_draw) + augmentation_params["horizontal_shift"] = horizontal_shift / aug_mask_left.width + chosen_parameters.append(augmentation_params) + mask_is_found = True + break + + if not mask_is_found: + break + + return chosen_parameters + + def _prepare_mask(self, mask): + height, width = mask.shape + target_width = width if self._is_power_of_two(width) else (1 << width.bit_length()) + target_height = height if self._is_power_of_two(height) else (1 << height.bit_length()) + + return resize(mask.astype('float32'), (target_height, target_width), order=0, mode='edge').round().astype('int32') + + def get_masks(self, im, return_panoptic=False): + panoptic_seg, segments_info = self.get_segmentation(im) + potential_mask_ids = self.identify_candidates(panoptic_seg, segments_info) + + panoptic_seg_scaled = self._prepare_mask(panoptic_seg.detach().cpu().numpy()) + downsampled = self.downsample_mask(panoptic_seg_scaled) + scene_objects = [] + for segment in segments_info: + if not segment["isthing"]: + continue + mask = downsampled == segment["id"] + if not np.any(mask): + continue + scene_objects.append(mask) + + mask_set = [] + for mask_id in potential_mask_ids: + mask = downsampled == mask_id + if not np.any(mask): + continue + + if self.rigidness_mode is RigidnessMode.soft: + foreground = [mask] + elif self.rigidness_mode is RigidnessMode.rigid: + foreground = scene_objects + else: + raise ValueError(f'Unexpected rigidness_mode: {rigidness_mode}') + + masks_params = self._move_mask(mask, foreground) + + full_mask = ObjectMask((panoptic_seg == mask_id).detach().cpu().numpy()) + + for params in masks_params: + aug_mask = deepcopy(full_mask) + aug_mask.rescale(params["scaling_factor"], inplace=True) + if params["horizontal_flip"]: + aug_mask.horizontal_flip(inplace=True) + + vertical_shift = int(round(aug_mask.height * params["vertical_shift"])) + horizontal_shift = int(round(aug_mask.width * params["horizontal_shift"])) + aug_mask.shift(vertical=vertical_shift, horizontal=horizontal_shift, inplace=True) + aug_mask = aug_mask.restore_full_mask().astype('uint8') + if aug_mask.mean() <= self.min_mask_area: + continue + mask_set.append(aug_mask) + + if return_panoptic: + return mask_set, panoptic_seg.detach().cpu().numpy() + else: + return mask_set + + +def propose_random_square_crop(mask, min_overlap=0.5): + height, width = mask.shape + mask_ys, mask_xs = np.where(mask > 0.5) # mask==0 is known fragment and mask==1 is missing + + if height < width: + crop_size = height + obj_left, obj_right = mask_xs.min(), mask_xs.max() + obj_width = obj_right - obj_left + left_border = max(0, min(width - crop_size - 1, obj_left + obj_width * min_overlap - crop_size)) + right_border = max(left_border + 1, min(width - crop_size, obj_left + obj_width * min_overlap)) + start_x = np.random.randint(left_border, right_border) + return start_x, 0, start_x + crop_size, height + else: + crop_size = width + obj_top, obj_bottom = mask_ys.min(), mask_ys.max() + obj_height = obj_bottom - obj_top + top_border = max(0, min(height - crop_size - 1, obj_top + obj_height * min_overlap - crop_size)) + bottom_border = max(top_border + 1, min(height - crop_size, obj_top + obj_height * min_overlap)) + start_y = np.random.randint(top_border, bottom_border) + return 0, start_y, width, start_y + crop_size diff --git a/saicinpainting/evaluation/utils.py b/saicinpainting/evaluation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7c15c9242ed8a9bc59fbb3b450cca394720bb8 --- /dev/null +++ b/saicinpainting/evaluation/utils.py @@ -0,0 +1,28 @@ +from enum import Enum + +import yaml +from easydict import EasyDict as edict +import torch.nn as nn +import torch + + +def load_yaml(path): + with open(path, 'r') as f: + return edict(yaml.safe_load(f)) + + +def move_to_device(obj, device): + if isinstance(obj, nn.Module): + return obj.to(device) + if torch.is_tensor(obj): + return obj.to(device) + if isinstance(obj, (tuple, list)): + return [move_to_device(el, device) for el in obj] + if isinstance(obj, dict): + return {name: move_to_device(val, device) for name, val in obj.items()} + raise ValueError(f'Unexpected type {type(obj)}') + + +class SmallMode(Enum): + DROP = "drop" + UPSCALE = "upscale" diff --git a/saicinpainting/evaluation/vis.py b/saicinpainting/evaluation/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..c2910b4ef8c61efee72dabd0531a9b669ec8bf98 --- /dev/null +++ b/saicinpainting/evaluation/vis.py @@ -0,0 +1,37 @@ +import numpy as np +from skimage import io +from skimage.segmentation import mark_boundaries + + +def save_item_for_vis(item, out_file): + mask = item['mask'] > 0.5 + if mask.ndim == 3: + mask = mask[0] + img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)), + mask, + color=(1., 0., 0.), + outline_color=(1., 1., 1.), + mode='thick') + + if 'inpainted' in item: + inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)), + mask, + color=(1., 0., 0.), + mode='outer') + img = np.concatenate((img, inp_img), axis=1) + + img = np.clip(img * 255, 0, 255).astype('uint8') + io.imsave(out_file, img) + + +def save_mask_for_sidebyside(item, out_file): + mask = item['mask']# > 0.5 + if mask.ndim == 3: + mask = mask[0] + mask = np.clip(mask * 255, 0, 255).astype('uint8') + io.imsave(out_file, mask) + +def save_img_for_sidebyside(item, out_file): + img = np.transpose(item['image'], (1, 2, 0)) + img = np.clip(img * 255, 0, 255).astype('uint8') + io.imsave(out_file, img) \ No newline at end of file diff --git a/saicinpainting/training/__init__.py b/saicinpainting/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/training/data/__init__.py b/saicinpainting/training/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/training/data/aug.py b/saicinpainting/training/data/aug.py new file mode 100644 index 0000000000000000000000000000000000000000..b1246250924e79511b58cd3d7ab79de8012f8949 --- /dev/null +++ b/saicinpainting/training/data/aug.py @@ -0,0 +1,84 @@ +from albumentations import DualIAATransform, to_tuple +import imgaug.augmenters as iaa + +class IAAAffine2(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__( + self, + scale=(0.7, 1.3), + translate_percent=None, + translate_px=None, + rotate=0.0, + shear=(-0.1, 0.1), + order=1, + cval=0, + mode="reflect", + always_apply=False, + p=0.5, + ): + super(IAAAffine2, self).__init__(always_apply, p) + self.scale = dict(x=scale, y=scale) + self.translate_percent = to_tuple(translate_percent, 0) + self.translate_px = to_tuple(translate_px, 0) + self.rotate = to_tuple(rotate) + self.shear = dict(x=shear, y=shear) + self.order = order + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.Affine( + self.scale, + self.translate_percent, + self.translate_px, + self.rotate, + self.shear, + self.order, + self.cval, + self.mode, + ) + + def get_transform_init_args_names(self): + return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode") + + +class IAAPerspective2(DualIAATransform): + """Perform a random four point perspective transform of the input. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5, + order=1, cval=0, mode="replicate"): + super(IAAPerspective2, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.keep_size = keep_size + self.cval = cval + self.mode = mode + + @property + def processor(self): + return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval) + + def get_transform_init_args_names(self): + return ("scale", "keep_size") diff --git a/saicinpainting/training/data/datasets.py b/saicinpainting/training/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..c4f503dafffb970d8dbaca33934da417036d1e55 --- /dev/null +++ b/saicinpainting/training/data/datasets.py @@ -0,0 +1,304 @@ +import glob +import logging +import os +import random + +import albumentations as A +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import webdataset +from omegaconf import open_dict, OmegaConf +from skimage.feature import canny +from skimage.transform import rescale, resize +from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset + +from saicinpainting.evaluation.data import InpaintingDataset as InpaintingEvaluationDataset, \ + OurInpaintingDataset as OurInpaintingEvaluationDataset, ceil_modulo, InpaintingEvalOnlineDataset +from saicinpainting.training.data.aug import IAAAffine2, IAAPerspective2 +from saicinpainting.training.data.masks import get_mask_generator + +LOGGER = logging.getLogger(__name__) + + +class InpaintingTrainDataset(Dataset): + def __init__(self, indir, mask_generator, transform): + self.in_files = list(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True)) + self.mask_generator = mask_generator + self.transform = transform + self.iter_i = 0 + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, item): + path = self.in_files[item] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + # TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks + mask = self.mask_generator(img, iter_i=self.iter_i) + self.iter_i += 1 + return dict(image=img, + mask=mask) + + +class InpaintingTrainWebDataset(IterableDataset): + def __init__(self, indir, mask_generator, transform, shuffle_buffer=200): + self.impl = webdataset.Dataset(indir).shuffle(shuffle_buffer).decode('rgb').to_tuple('jpg') + self.mask_generator = mask_generator + self.transform = transform + + def __iter__(self): + for iter_i, (img,) in enumerate(self.impl): + img = np.clip(img * 255, 0, 255).astype('uint8') + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + mask = self.mask_generator(img, iter_i=iter_i) + yield dict(image=img, + mask=mask) + + +class ImgSegmentationDataset(Dataset): + def __init__(self, indir, mask_generator, transform, out_size, segm_indir, semantic_seg_n_classes): + self.indir = indir + self.segm_indir = segm_indir + self.mask_generator = mask_generator + self.transform = transform + self.out_size = out_size + self.semantic_seg_n_classes = semantic_seg_n_classes + self.in_files = list(glob.glob(os.path.join(indir, '**', '*.jpg'), recursive=True)) + + def __len__(self): + return len(self.in_files) + + def __getitem__(self, item): + path = self.in_files[item] + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, (self.out_size, self.out_size)) + img = self.transform(image=img)['image'] + img = np.transpose(img, (2, 0, 1)) + mask = self.mask_generator(img) + segm, segm_classes= self.load_semantic_segm(path) + result = dict(image=img, + mask=mask, + segm=segm, + segm_classes=segm_classes) + return result + + def load_semantic_segm(self, img_path): + segm_path = img_path.replace(self.indir, self.segm_indir).replace(".jpg", ".png") + mask = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE) + mask = cv2.resize(mask, (self.out_size, self.out_size)) + tensor = torch.from_numpy(np.clip(mask.astype(int)-1, 0, None)) + ohe = F.one_hot(tensor.long(), num_classes=self.semantic_seg_n_classes) # w x h x n_classes + return ohe.permute(2, 0, 1).float(), tensor.unsqueeze(0) + + +def get_transforms(transform_variant, out_size): + if transform_variant == 'default': + transform = A.Compose([ + A.RandomScale(scale_limit=0.2), # +/- 20% + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.7, 1.3), + rotate=(-40, 40), + shear=(-0.1, 0.1)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_scale05_1': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.5, 1.0), + rotate=(-40, 40), + shear=(-0.1, 0.1), + p=1), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_scale03_12': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.3, 1.2), + rotate=(-40, 40), + shear=(-0.1, 0.1), + p=1), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_scale03_07': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.06)), + IAAAffine2(scale=(0.3, 0.7), # scale 512 to 256 in average + rotate=(-40, 40), + shear=(-0.1, 0.1), + p=1), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.OpticalDistortion(), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'distortions_light': + transform = A.Compose([ + IAAPerspective2(scale=(0.0, 0.02)), + IAAAffine2(scale=(0.8, 1.8), + rotate=(-20, 20), + shear=(-0.03, 0.03)), + A.PadIfNeeded(min_height=out_size, min_width=out_size), + A.RandomCrop(height=out_size, width=out_size), + A.HorizontalFlip(), + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'non_space_transform': + transform = A.Compose([ + A.CLAHE(), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2), + A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=30, val_shift_limit=5), + A.ToFloat() + ]) + elif transform_variant == 'no_augs': + transform = A.Compose([ + A.ToFloat() + ]) + else: + raise ValueError(f'Unexpected transform_variant {transform_variant}') + return transform + + +def make_default_train_dataloader(indir, kind='default', out_size=512, mask_gen_kwargs=None, transform_variant='default', + mask_generator_kind="mixed", dataloader_kwargs=None, ddp_kwargs=None, **kwargs): + LOGGER.info(f'Make train dataloader {kind} from {indir}. Using mask generator={mask_generator_kind}') + + mask_generator = get_mask_generator(kind=mask_generator_kind, kwargs=mask_gen_kwargs) + transform = get_transforms(transform_variant, out_size) + + if kind == 'default': + dataset = InpaintingTrainDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + **kwargs) + elif kind == 'default_web': + dataset = InpaintingTrainWebDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + **kwargs) + elif kind == 'img_with_segm': + dataset = ImgSegmentationDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + out_size=out_size, + **kwargs) + else: + raise ValueError(f'Unknown train dataset kind {kind}') + + if dataloader_kwargs is None: + dataloader_kwargs = {} + + is_dataset_only_iterable = kind in ('default_web',) + + if ddp_kwargs is not None and not is_dataset_only_iterable: + dataloader_kwargs['shuffle'] = False + dataloader_kwargs['sampler'] = DistributedSampler(dataset, **ddp_kwargs) + + if is_dataset_only_iterable and 'shuffle' in dataloader_kwargs: + with open_dict(dataloader_kwargs): + del dataloader_kwargs['shuffle'] + + dataloader = DataLoader(dataset, **dataloader_kwargs) + return dataloader + + +def make_default_val_dataset(indir, kind='default', out_size=512, transform_variant='default', **kwargs): + if OmegaConf.is_list(indir) or isinstance(indir, (tuple, list)): + return ConcatDataset([ + make_default_val_dataset(idir, kind=kind, out_size=out_size, transform_variant=transform_variant, **kwargs) for idir in indir + ]) + + LOGGER.info(f'Make val dataloader {kind} from {indir}') + mask_generator = get_mask_generator(kind=kwargs.get("mask_generator_kind"), kwargs=kwargs.get("mask_gen_kwargs")) + + if transform_variant is not None: + transform = get_transforms(transform_variant, out_size) + + if kind == 'default': + dataset = InpaintingEvaluationDataset(indir, **kwargs) + elif kind == 'our_eval': + dataset = OurInpaintingEvaluationDataset(indir, **kwargs) + elif kind == 'img_with_segm': + dataset = ImgSegmentationDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + out_size=out_size, + **kwargs) + elif kind == 'online': + dataset = InpaintingEvalOnlineDataset(indir=indir, + mask_generator=mask_generator, + transform=transform, + out_size=out_size, + **kwargs) + else: + raise ValueError(f'Unknown val dataset kind {kind}') + + return dataset + + +def make_default_val_dataloader(*args, dataloader_kwargs=None, **kwargs): + dataset = make_default_val_dataset(*args, **kwargs) + + if dataloader_kwargs is None: + dataloader_kwargs = {} + dataloader = DataLoader(dataset, **dataloader_kwargs) + return dataloader + + +def make_constant_area_crop_params(img_height, img_width, min_size=128, max_size=512, area=256*256, round_to_mod=16): + min_size = min(img_height, img_width, min_size) + max_size = min(img_height, img_width, max_size) + if random.random() < 0.5: + out_height = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod)) + out_width = min(max_size, ceil_modulo(area // out_height, round_to_mod)) + else: + out_width = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod)) + out_height = min(max_size, ceil_modulo(area // out_width, round_to_mod)) + + start_y = random.randint(0, img_height - out_height) + start_x = random.randint(0, img_width - out_width) + return (start_y, start_x, out_height, out_width) diff --git a/saicinpainting/training/data/masks.py b/saicinpainting/training/data/masks.py new file mode 100644 index 0000000000000000000000000000000000000000..e91fc74913356481065c5f5906acd50fb05f521c --- /dev/null +++ b/saicinpainting/training/data/masks.py @@ -0,0 +1,332 @@ +import math +import random +import hashlib +import logging +from enum import Enum + +import cv2 +import numpy as np + +from saicinpainting.evaluation.masks.mask import SegmentationMask +from saicinpainting.utils import LinearRamp + +LOGGER = logging.getLogger(__name__) + + +class DrawMethod(Enum): + LINE = 'line' + CIRCLE = 'circle' + SQUARE = 'square' + + +def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, + draw_method=DrawMethod.LINE): + draw_method = DrawMethod(draw_method) + + height, width = shape + mask = np.zeros((height, width), np.float32) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + start_x = np.random.randint(width) + start_y = np.random.randint(height) + for j in range(1 + np.random.randint(5)): + angle = 0.01 + np.random.randint(max_angle) + if i % 2 == 0: + angle = 2 * 3.1415926 - angle + length = 10 + np.random.randint(max_len) + brush_w = 5 + np.random.randint(max_width) + end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width) + end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height) + if draw_method == DrawMethod.LINE: + cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w) + elif draw_method == DrawMethod.CIRCLE: + cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1) + elif draw_method == DrawMethod.SQUARE: + radius = brush_w // 2 + mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1 + start_x, start_y = end_x, end_y + return mask[None, ...] + + +class RandomIrregularMaskGenerator: + def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None, + draw_method=DrawMethod.LINE): + self.max_angle = max_angle + self.max_len = max_len + self.max_width = max_width + self.min_times = min_times + self.max_times = max_times + self.draw_method = draw_method + self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1 + cur_max_len = int(max(1, self.max_len * coef)) + cur_max_width = int(max(1, self.max_width * coef)) + cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef) + return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len, + max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times, + draw_method=self.draw_method) + + +def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) + times = np.random.randint(min_times, max_times + 1) + for i in range(times): + box_width = np.random.randint(bbox_min_size, bbox_max_size) + box_height = np.random.randint(bbox_min_size, bbox_max_size) + start_x = np.random.randint(margin, width - margin - box_width + 1) + start_y = np.random.randint(margin, height - margin - box_height + 1) + mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1 + return mask[None, ...] + + +class RandomRectangleMaskGenerator: + def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None): + self.margin = margin + self.bbox_min_size = bbox_min_size + self.bbox_max_size = bbox_max_size + self.min_times = min_times + self.max_times = max_times + self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None + + def __call__(self, img, iter_i=None, raw_image=None): + coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1 + cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef) + cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef) + return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size, + bbox_max_size=cur_bbox_max_size, min_times=self.min_times, + max_times=cur_max_times) + + +class RandomSegmentationMaskGenerator: + def __init__(self, **kwargs): + self.impl = None # will be instantiated in first call (effectively in subprocess) + self.kwargs = kwargs + + def __call__(self, img, iter_i=None, raw_image=None): + if self.impl is None: + self.impl = SegmentationMask(**self.kwargs) + + masks = self.impl.get_masks(np.transpose(img, (1, 2, 0))) + masks = [m for m in masks if len(np.unique(m)) > 1] + return np.random.choice(masks) + + +def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3): + height, width = shape + mask = np.zeros((height, width), np.float32) + step_x = np.random.randint(min_step, max_step + 1) + width_x = np.random.randint(min_width, min(step_x, max_width + 1)) + offset_x = np.random.randint(0, step_x) + + step_y = np.random.randint(min_step, max_step + 1) + width_y = np.random.randint(min_width, min(step_y, max_width + 1)) + offset_y = np.random.randint(0, step_y) + + for dy in range(width_y): + mask[offset_y + dy::step_y] = 1 + for dx in range(width_x): + mask[:, offset_x + dx::step_x] = 1 + return mask[None, ...] + + +class RandomSuperresMaskGenerator: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __call__(self, img, iter_i=None): + return make_random_superres_mask(img.shape[1:], **self.kwargs) + + +class DumbAreaMaskGenerator: + min_ratio = 0.1 + max_ratio = 0.35 + default_ratio = 0.225 + + def __init__(self, is_training): + #Parameters: + # is_training(bool): If true - random rectangular mask, if false - central square mask + self.is_training = is_training + + def _random_vector(self, dimension): + if self.is_training: + lower_limit = math.sqrt(self.min_ratio) + upper_limit = math.sqrt(self.max_ratio) + mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension) + u = random.randint(0, dimension-mask_side-1) + v = u+mask_side + else: + margin = (math.sqrt(self.default_ratio) / 2) * dimension + u = round(dimension/2 - margin) + v = round(dimension/2 + margin) + return u, v + + def __call__(self, img, iter_i=None, raw_image=None): + c, height, width = img.shape + mask = np.zeros((height, width), np.float32) + x1, x2 = self._random_vector(width) + y1, y2 = self._random_vector(height) + mask[x1:x2, y1:y2] = 1 + return mask[None, ...] + + +class OutpaintingMaskGenerator: + def __init__(self, min_padding_percent:float=0.04, max_padding_percent:int=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5, + right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False): + """ + is_fixed_randomness - get identical paddings for the same image if args are the same + """ + self.min_padding_percent = min_padding_percent + self.max_padding_percent = max_padding_percent + self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob] + self.is_fixed_randomness = is_fixed_randomness + + assert self.min_padding_percent <= self.max_padding_percent + assert self.max_padding_percent > 0 + assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]" + assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}" + assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}" + if len([x for x in self.probs if x > 0]) == 1: + LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side") + + def apply_padding(self, mask, coord): + mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h), + int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1 + return mask + + def get_padding(self, size): + n1 = int(self.min_padding_percent*size) + n2 = int(self.max_padding_percent*size) + return self.rnd.randint(n1, n2) / size + + @staticmethod + def _img2rs(img): + arr = np.ascontiguousarray(img.astype(np.uint8)) + str_hash = hashlib.sha1(arr).hexdigest() + res = hash(str_hash)%(2**32) + return res + + def __call__(self, img, iter_i=None, raw_image=None): + c, self.img_h, self.img_w = img.shape + mask = np.zeros((self.img_h, self.img_w), np.float32) + at_least_one_mask_applied = False + + if self.is_fixed_randomness: + assert raw_image is not None, f"Cant calculate hash on raw_image=None" + rs = self._img2rs(raw_image) + self.rnd = np.random.RandomState(rs) + else: + self.rnd = np.random + + coords = [[ + (0,0), + (1,self.get_padding(size=self.img_h)) + ], + [ + (0,0), + (self.get_padding(size=self.img_w),1) + ], + [ + (0,1-self.get_padding(size=self.img_h)), + (1,1) + ], + [ + (1-self.get_padding(size=self.img_w),0), + (1,1) + ]] + + for pp, coord in zip(self.probs, coords): + if self.rnd.random() < pp: + at_least_one_mask_applied = True + mask = self.apply_padding(mask=mask, coord=coord) + + if not at_least_one_mask_applied: + idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs)) + mask = self.apply_padding(mask=mask, coord=coords[idx]) + return mask[None, ...] + + +class MixedMaskGenerator: + def __init__(self, irregular_proba=1/3, irregular_kwargs=None, + box_proba=1/3, box_kwargs=None, + segm_proba=1/3, segm_kwargs=None, + squares_proba=0, squares_kwargs=None, + superres_proba=0, superres_kwargs=None, + outpainting_proba=0, outpainting_kwargs=None, + invert_proba=0): + self.probas = [] + self.gens = [] + + if irregular_proba > 0: + self.probas.append(irregular_proba) + if irregular_kwargs is None: + irregular_kwargs = {} + else: + irregular_kwargs = dict(irregular_kwargs) + irregular_kwargs['draw_method'] = DrawMethod.LINE + self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) + + if box_proba > 0: + self.probas.append(box_proba) + if box_kwargs is None: + box_kwargs = {} + self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) + + if segm_proba > 0: + self.probas.append(segm_proba) + if segm_kwargs is None: + segm_kwargs = {} + self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs)) + + if squares_proba > 0: + self.probas.append(squares_proba) + if squares_kwargs is None: + squares_kwargs = {} + else: + squares_kwargs = dict(squares_kwargs) + squares_kwargs['draw_method'] = DrawMethod.SQUARE + self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) + + if superres_proba > 0: + self.probas.append(superres_proba) + if superres_kwargs is None: + superres_kwargs = {} + self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) + + if outpainting_proba > 0: + self.probas.append(outpainting_proba) + if outpainting_kwargs is None: + outpainting_kwargs = {} + self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs)) + + self.probas = np.array(self.probas, dtype='float32') + self.probas /= self.probas.sum() + self.invert_proba = invert_proba + + def __call__(self, img, iter_i=None, raw_image=None): + kind = np.random.choice(len(self.probas), p=self.probas) + gen = self.gens[kind] + result = gen(img, iter_i=iter_i, raw_image=raw_image) + if self.invert_proba > 0 and random.random() < self.invert_proba: + result = 1 - result + return result + + +def get_mask_generator(kind, kwargs): + if kind is None: + kind = "mixed" + if kwargs is None: + kwargs = {} + + if kind == "mixed": + cl = MixedMaskGenerator + elif kind == "outpainting": + cl = OutpaintingMaskGenerator + elif kind == "dumb": + cl = DumbAreaMaskGenerator + else: + raise NotImplementedError(f"No such generator kind = {kind}") + return cl(**kwargs) diff --git a/saicinpainting/training/losses/__init__.py b/saicinpainting/training/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/saicinpainting/training/losses/adversarial.py b/saicinpainting/training/losses/adversarial.py new file mode 100644 index 0000000000000000000000000000000000000000..d6db2967ce5074d94ed3b4c51fc743ff2f7831b1 --- /dev/null +++ b/saicinpainting/training/losses/adversarial.py @@ -0,0 +1,177 @@ +from typing import Tuple, Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BaseAdversarialLoss: + def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + """ + Prepare for generator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + """ + Prepare for discriminator step + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param generator: + :param discriminator: + :return: None + """ + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate generator loss + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total generator loss along with some values that might be interesting to log + """ + raise NotImplemented() + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask: Optional[torch.Tensor] = None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Calculate discriminator loss and call .backward() on it + :param real_batch: Tensor, a batch of real samples + :param fake_batch: Tensor, a batch of samples produced by generator + :param discr_real_pred: Tensor, discriminator output for real_batch + :param discr_fake_pred: Tensor, discriminator output for fake_batch + :param mask: Tensor, actual mask, which was at input of generator when making fake_batch + :return: total discriminator loss along with some values that might be interesting to log + """ + raise NotImplemented() + + def interpolate_mask(self, mask, shape): + assert mask is not None + assert self.allow_scale_mask or shape == mask.shape[-2:] + if shape != mask.shape[-2:] and self.allow_scale_mask: + if self.mask_scale_mode == 'maxpool': + mask = F.adaptive_max_pool2d(mask, shape) + else: + mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode) + return mask + +def make_r1_gp(discr_real_pred, real_batch): + if torch.is_grad_enabled(): + grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0] + grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean() + else: + grad_penalty = 0 + real_batch.requires_grad = False + + return grad_penalty + +class NonSaturatingWithR1(BaseAdversarialLoss): + def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False, + mask_scale_mode='nearest', extra_mask_weight_for_gen=0, + use_unmasked_for_gen=True, use_unmasked_for_discr=True): + self.gp_coef = gp_coef + self.weight = weight + # use for discr => use for gen; + # otherwise we teach only the discr to pay attention to very small difference + assert use_unmasked_for_gen or (not use_unmasked_for_discr) + # mask as target => use unmasked for discr: + # if we don't care about unmasked regions at all + # then it doesn't matter if the value of mask_as_fake_target is true or false + assert use_unmasked_for_discr or (not mask_as_fake_target) + self.use_unmasked_for_gen = use_unmasked_for_gen + self.use_unmasked_for_discr = use_unmasked_for_discr + self.mask_as_fake_target = mask_as_fake_target + self.allow_scale_mask = allow_scale_mask + self.mask_scale_mode = mask_scale_mode + self.extra_mask_weight_for_gen = extra_mask_weight_for_gen + + def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + fake_loss = F.softplus(-discr_fake_pred) + if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \ + not self.use_unmasked_for_gen: # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + if not self.use_unmasked_for_gen: + fake_loss = fake_loss * mask + else: + pixel_weights = 1 + mask * self.extra_mask_weight_for_gen + fake_loss = fake_loss * pixel_weights + + return fake_loss.mean() * self.weight, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor, + mask=None) \ + -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + real_loss = F.softplus(-discr_real_pred) + grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef + fake_loss = F.softplus(discr_fake_pred) + + if not self.use_unmasked_for_discr or self.mask_as_fake_target: + # == if masked region should be treated differently + mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) + # use_unmasked_for_discr=False only makes sense for fakes; + # for reals there is no difference beetween two regions + fake_loss = fake_loss * mask + if self.mask_as_fake_target: + fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred) + + sum_discr_loss = real_loss + grad_penalty + fake_loss + metrics = dict(discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=grad_penalty) + return sum_discr_loss.mean(), metrics + +class BCELoss(BaseAdversarialLoss): + def __init__(self, weight): + self.weight = weight + self.bce_loss = nn.BCEWithLogitsLoss() + + def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device) + fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight + return fake_loss, dict() + + def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor, + generator: nn.Module, discriminator: nn.Module): + real_batch.requires_grad = True + + def discriminator_loss(self, + mask: torch.Tensor, + discr_real_pred: torch.Tensor, + discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + + real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device) + sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2 + metrics = dict(discr_real_out=discr_real_pred.mean(), + discr_fake_out=discr_fake_pred.mean(), + discr_real_gp=0) + return sum_discr_loss, metrics + + +def make_discrim_loss(kind, **kwargs): + if kind == 'r1': + return NonSaturatingWithR1(**kwargs) + elif kind == 'bce': + return BCELoss(**kwargs) + raise ValueError(f'Unknown adversarial loss kind {kind}') diff --git a/saicinpainting/training/losses/constants.py b/saicinpainting/training/losses/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3e5e151342232be8e2c2a77fe6fd5798dc2a8c --- /dev/null +++ b/saicinpainting/training/losses/constants.py @@ -0,0 +1,152 @@ +weights = {"ade20k": + [6.34517766497462, + 9.328358208955224, + 11.389521640091116, + 16.10305958132045, + 20.833333333333332, + 22.22222222222222, + 25.125628140703515, + 43.29004329004329, + 50.5050505050505, + 54.6448087431694, + 55.24861878453038, + 60.24096385542168, + 62.5, + 66.2251655629139, + 84.74576271186442, + 90.90909090909092, + 91.74311926605505, + 96.15384615384616, + 96.15384615384616, + 97.08737864077669, + 102.04081632653062, + 135.13513513513513, + 149.2537313432836, + 153.84615384615384, + 163.93442622950818, + 166.66666666666666, + 188.67924528301887, + 192.30769230769232, + 217.3913043478261, + 227.27272727272725, + 227.27272727272725, + 227.27272727272725, + 303.03030303030306, + 322.5806451612903, + 333.3333333333333, + 370.3703703703703, + 384.61538461538464, + 416.6666666666667, + 416.6666666666667, + 434.7826086956522, + 434.7826086956522, + 454.5454545454545, + 454.5454545454545, + 500.0, + 526.3157894736842, + 526.3157894736842, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 555.5555555555555, + 588.2352941176471, + 588.2352941176471, + 588.2352941176471, + 588.2352941176471, + 588.2352941176471, + 666.6666666666666, + 666.6666666666666, + 666.6666666666666, + 666.6666666666666, + 714.2857142857143, + 714.2857142857143, + 714.2857142857143, + 714.2857142857143, + 714.2857142857143, + 769.2307692307693, + 769.2307692307693, + 769.2307692307693, + 833.3333333333334, + 833.3333333333334, + 833.3333333333334, + 833.3333333333334, + 909.090909090909, + 1000.0, + 1111.111111111111, + 1111.111111111111, + 1111.111111111111, + 1111.111111111111, + 1111.111111111111, + 1250.0, + 1250.0, + 1250.0, + 1250.0, + 1250.0, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1428.5714285714287, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 1666.6666666666667, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 2500.0, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 3333.3333333333335, + 5000.0, + 5000.0, + 5000.0] +} \ No newline at end of file diff --git a/saicinpainting/training/losses/distance_weighting.py b/saicinpainting/training/losses/distance_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..93052003b1e47fd663c70aedcecd144171f49204 --- /dev/null +++ b/saicinpainting/training/losses/distance_weighting.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN + + +def dummy_distance_weighter(real_img, pred_img, mask): + return mask + + +def get_gauss_kernel(kernel_size, width_factor=1): + coords = torch.stack(torch.meshgrid(torch.arange(kernel_size), + torch.arange(kernel_size)), + dim=0).float() + diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor) + diff /= diff.sum() + return diff + + +class BlurMask(nn.Module): + def __init__(self, kernel_size=5, width_factor=1): + super().__init__() + self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False) + self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor)) + + def forward(self, real_img, pred_img, mask): + with torch.no_grad(): + result = self.filter(mask) * mask + return result + + +class EmulatedEDTMask(nn.Module): + def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1): + super().__init__() + self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size// 2, padding_mode='replicate', + bias=False) + self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float)) + self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False) + self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor)) + + def forward(self, real_img, pred_img, mask): + with torch.no_grad(): + known_mask = 1 - mask + dilated_known_mask = (self.dilate_filter(known_mask) > 1).float() + result = self.blur_filter(1 - dilated_known_mask) * mask + return result + + +class PropagatePerceptualSim(nn.Module): + def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3): + super().__init__() + vgg = torchvision.models.vgg19(pretrained=True).features + vgg_avg_pooling = [] + + for weights in vgg.parameters(): + weights.requires_grad = False + + cur_level_i = 0 + for module in vgg.modules(): + if module.__class__.__name__ == 'Sequential': + continue + elif module.__class__.__name__ == 'MaxPool2d': + vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) + else: + vgg_avg_pooling.append(module) + if module.__class__.__name__ == 'ReLU': + cur_level_i += 1 + if cur_level_i == level: + break + + self.features = nn.Sequential(*vgg_avg_pooling) + + self.max_iters = max_iters + self.temperature = temperature + self.do_erode = erode_mask_size > 0 + if self.do_erode: + self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False) + self.erode_mask.weight.data.fill_(1) + + def forward(self, real_img, pred_img, mask): + with torch.no_grad(): + real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img) + real_feats = self.features(real_img) + + vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True) + / self.temperature) + horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True) + / self.temperature) + + mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False) + if self.do_erode: + mask_scaled = (self.erode_mask(mask_scaled) > 1).float() + + cur_knowness = 1 - mask_scaled + + for iter_i in range(self.max_iters): + new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate') + new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate') + + new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate') + new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate') + + new_knowness = torch.stack([new_top_knowness, new_bottom_knowness, + new_left_knowness, new_right_knowness], + dim=0).max(0).values + + cur_knowness = torch.max(cur_knowness, new_knowness) + + cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear') + result = torch.min(mask, 1 - cur_knowness) + + return result + + +def make_mask_distance_weighter(kind='none', **kwargs): + if kind == 'none': + return dummy_distance_weighter + if kind == 'blur': + return BlurMask(**kwargs) + if kind == 'edt': + return EmulatedEDTMask(**kwargs) + if kind == 'pps': + return PropagatePerceptualSim(**kwargs) + raise ValueError(f'Unknown mask distance weighter kind {kind}') diff --git a/saicinpainting/training/losses/feature_matching.py b/saicinpainting/training/losses/feature_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c019895c9178817837d1a6773367b178a861dc61 --- /dev/null +++ b/saicinpainting/training/losses/feature_matching.py @@ -0,0 +1,33 @@ +from typing import List + +import torch +import torch.nn.functional as F + + +def masked_l2_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l2 = F.mse_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l2).mean() + + +def masked_l1_loss(pred, target, mask, weight_known, weight_missing): + per_pixel_l1 = F.l1_loss(pred, target, reduction='none') + pixel_weights = mask * weight_missing + (1 - mask) * weight_known + return (pixel_weights * per_pixel_l1).mean() + + +def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): + if mask is None: + res = torch.stack([F.mse_loss(fake_feat, target_feat) + for fake_feat, target_feat in zip(fake_features, target_features)]).mean() + else: + res = 0 + norm = 0 + for fake_feat, target_feat in zip(fake_features, target_features): + cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) + error_weights = 1 - cur_mask + cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() + res = res + cur_val + norm += 1 + res = res / norm + return res diff --git a/saicinpainting/training/losses/perceptual.py b/saicinpainting/training/losses/perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..8c055c2b327ce7943682af5c5f9394b9fcbec506 --- /dev/null +++ b/saicinpainting/training/losses/perceptual.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from models.ade20k import ModelBuilder +from saicinpainting.utils import check_and_warn_input_range + + +IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] +IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] + + +class PerceptualLoss(nn.Module): + def __init__(self, normalize_inputs=True): + super(PerceptualLoss, self).__init__() + + self.normalize_inputs = normalize_inputs + self.mean_ = IMAGENET_MEAN + self.std_ = IMAGENET_STD + + vgg = torchvision.models.vgg19(pretrained=True).features + vgg_avg_pooling = [] + + for weights in vgg.parameters(): + weights.requires_grad = False + + for module in vgg.modules(): + if module.__class__.__name__ == 'Sequential': + continue + elif module.__class__.__name__ == 'MaxPool2d': + vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) + else: + vgg_avg_pooling.append(module) + + self.vgg = nn.Sequential(*vgg_avg_pooling) + + def do_normalize_inputs(self, x): + return (x - self.mean_.to(x.device)) / self.std_.to(x.device) + + def partial_losses(self, input, target, mask=None): + check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses') + + # we expect input and target to be in [0, 1] range + losses = [] + + if self.normalize_inputs: + features_input = self.do_normalize_inputs(input) + features_target = self.do_normalize_inputs(target) + else: + features_input = input + features_target = target + + for layer in self.vgg[:30]: + + features_input = layer(features_input) + features_target = layer(features_target) + + if layer.__class__.__name__ == 'ReLU': + loss = F.mse_loss(features_input, features_target, reduction='none') + + if mask is not None: + cur_mask = F.interpolate(mask, size=features_input.shape[-2:], + mode='bilinear', align_corners=False) + loss = loss * (1 - cur_mask) + + loss = loss.mean(dim=tuple(range(1, len(loss.shape)))) + losses.append(loss) + + return losses + + def forward(self, input, target, mask=None): + losses = self.partial_losses(input, target, mask=mask) + return torch.stack(losses).sum(dim=0) + + def get_global_features(self, input): + check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features') + + if self.normalize_inputs: + features_input = self.do_normalize_inputs(input) + else: + features_input = input + + features_input = self.vgg(features_input) + return features_input + + +class ResNetPL(nn.Module): + def __init__(self, weight=1, + weights_path=None, arch_encoder='resnet50dilated', segmentation=True): + super().__init__() + self.impl = ModelBuilder.get_encoder(weights_path=weights_path, + arch_encoder=arch_encoder, + arch_decoder='ppm_deepsup', + fc_dim=2048, + segmentation=segmentation) + self.impl.eval() + for w in self.impl.parameters(): + w.requires_grad_(False) + + self.weight = weight + + def forward(self, pred, target): + pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred) + target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target) + + pred_feats = self.impl(pred, return_feature_maps=True) + target_feats = self.impl(target, return_feature_maps=True) + + result = torch.stack([F.mse_loss(cur_pred, cur_target) + for cur_pred, cur_target + in zip(pred_feats, target_feats)]).sum() * self.weight + return result diff --git a/saicinpainting/training/losses/segmentation.py b/saicinpainting/training/losses/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..3d4a9f94eaae84722db584277dbbf9bc41ede357 --- /dev/null +++ b/saicinpainting/training/losses/segmentation.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .constants import weights as constant_weights + + +class CrossEntropy2d(nn.Module): + def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): + """ + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size "nclasses" + """ + super(CrossEntropy2d, self).__init__() + self.reduction = reduction + self.ignore_label = ignore_label + self.weights = weights + if self.weights is not None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.weights = torch.FloatTensor(constant_weights[weights]).to(device) + + def forward(self, predict, target): + """ + Args: + predict:(n, c, h, w) + target:(n, 1, h, w) + """ + target = target.long() + assert not target.requires_grad + assert predict.dim() == 4, "{0}".format(predict.size()) + assert target.dim() == 4, "{0}".format(target.size()) + assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) + assert target.size(1) == 1, "{0}".format(target.size(1)) + assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) + assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) + target = target.squeeze(1) + n, c, h, w = predict.size() + target_mask = (target >= 0) * (target != self.ignore_label) + target = target[target_mask] + predict = predict.transpose(1, 2).transpose(2, 3).contiguous() + predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) + loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) + return loss diff --git a/saicinpainting/training/losses/style_loss.py b/saicinpainting/training/losses/style_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb42d7fbc5d17a47bec7365889868505f5fdfb5 --- /dev/null +++ b/saicinpainting/training/losses/style_loss.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn +import torchvision.models as models + + +class PerceptualLoss(nn.Module): + r""" + Perceptual loss, VGG-based + https://arxiv.org/abs/1603.08155 + https://github.com/dxyang/StyleTransfer/blob/master/utils.py + """ + + def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): + super(PerceptualLoss, self).__init__() + self.add_module('vgg', VGG19()) + self.criterion = torch.nn.L1Loss() + self.weights = weights + + def __call__(self, x, y): + # Compute features + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + + content_loss = 0.0 + content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) + content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) + content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) + content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) + content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) + + + return content_loss + + +class VGG19(torch.nn.Module): + def __init__(self): + super(VGG19, self).__init__() + features = models.vgg19(pretrained=True).features + self.relu1_1 = torch.nn.Sequential() + self.relu1_2 = torch.nn.Sequential() + + self.relu2_1 = torch.nn.Sequential() + self.relu2_2 = torch.nn.Sequential() + + self.relu3_1 = torch.nn.Sequential() + self.relu3_2 = torch.nn.Sequential() + self.relu3_3 = torch.nn.Sequential() + self.relu3_4 = torch.nn.Sequential() + + self.relu4_1 = torch.nn.Sequential() + self.relu4_2 = torch.nn.Sequential() + self.relu4_3 = torch.nn.Sequential() + self.relu4_4 = torch.nn.Sequential() + + self.relu5_1 = torch.nn.Sequential() + self.relu5_2 = torch.nn.Sequential() + self.relu5_3 = torch.nn.Sequential() + self.relu5_4 = torch.nn.Sequential() + + for x in range(2): + self.relu1_1.add_module(str(x), features[x]) + + for x in range(2, 4): + self.relu1_2.add_module(str(x), features[x]) + + for x in range(4, 7): + self.relu2_1.add_module(str(x), features[x]) + + for x in range(7, 9): + self.relu2_2.add_module(str(x), features[x]) + + for x in range(9, 12): + self.relu3_1.add_module(str(x), features[x]) + + for x in range(12, 14): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(14, 16): + self.relu3_2.add_module(str(x), features[x]) + + for x in range(16, 18): + self.relu3_4.add_module(str(x), features[x]) + + for x in range(18, 21): + self.relu4_1.add_module(str(x), features[x]) + + for x in range(21, 23): + self.relu4_2.add_module(str(x), features[x]) + + for x in range(23, 25): + self.relu4_3.add_module(str(x), features[x]) + + for x in range(25, 27): + self.relu4_4.add_module(str(x), features[x]) + + for x in range(27, 30): + self.relu5_1.add_module(str(x), features[x]) + + for x in range(30, 32): + self.relu5_2.add_module(str(x), features[x]) + + for x in range(32, 34): + self.relu5_3.add_module(str(x), features[x]) + + for x in range(34, 36): + self.relu5_4.add_module(str(x), features[x]) + + # don't need the gradients, just want the features + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + relu1_1 = self.relu1_1(x) + relu1_2 = self.relu1_2(relu1_1) + + relu2_1 = self.relu2_1(relu1_2) + relu2_2 = self.relu2_2(relu2_1) + + relu3_1 = self.relu3_1(relu2_2) + relu3_2 = self.relu3_2(relu3_1) + relu3_3 = self.relu3_3(relu3_2) + relu3_4 = self.relu3_4(relu3_3) + + relu4_1 = self.relu4_1(relu3_4) + relu4_2 = self.relu4_2(relu4_1) + relu4_3 = self.relu4_3(relu4_2) + relu4_4 = self.relu4_4(relu4_3) + + relu5_1 = self.relu5_1(relu4_4) + relu5_2 = self.relu5_2(relu5_1) + relu5_3 = self.relu5_3(relu5_2) + relu5_4 = self.relu5_4(relu5_3) + + out = { + 'relu1_1': relu1_1, + 'relu1_2': relu1_2, + + 'relu2_1': relu2_1, + 'relu2_2': relu2_2, + + 'relu3_1': relu3_1, + 'relu3_2': relu3_2, + 'relu3_3': relu3_3, + 'relu3_4': relu3_4, + + 'relu4_1': relu4_1, + 'relu4_2': relu4_2, + 'relu4_3': relu4_3, + 'relu4_4': relu4_4, + + 'relu5_1': relu5_1, + 'relu5_2': relu5_2, + 'relu5_3': relu5_3, + 'relu5_4': relu5_4, + } + return out diff --git a/saicinpainting/training/modules/__init__.py b/saicinpainting/training/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82e1a9096a5bd8f3fb00e899d0239b078246cad4 --- /dev/null +++ b/saicinpainting/training/modules/__init__.py @@ -0,0 +1,31 @@ +import logging + +from saicinpainting.training.modules.ffc import FFCResNetGenerator +from saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \ + NLayerDiscriminator, MultidilatedNLayerDiscriminator + +def make_generator(config, kind, **kwargs): + logging.info(f'Make generator {kind}') + + if kind == 'pix2pixhd_multidilated': + return MultiDilatedGlobalGenerator(**kwargs) + + if kind == 'pix2pixhd_global': + return GlobalGenerator(**kwargs) + + if kind == 'ffc_resnet': + return FFCResNetGenerator(**kwargs) + + raise ValueError(f'Unknown generator kind {kind}') + + +def make_discriminator(kind, **kwargs): + logging.info(f'Make discriminator {kind}') + + if kind == 'pix2pixhd_nlayer_multidilated': + return MultidilatedNLayerDiscriminator(**kwargs) + + if kind == 'pix2pixhd_nlayer': + return NLayerDiscriminator(**kwargs) + + raise ValueError(f'Unknown discriminator kind {kind}') diff --git a/saicinpainting/training/modules/base.py b/saicinpainting/training/modules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a50c3fc7753a0bba64a5ab8c1ed64ff97e62313f --- /dev/null +++ b/saicinpainting/training/modules/base.py @@ -0,0 +1,80 @@ +import abc +from typing import Tuple, List + +import torch +import torch.nn as nn + +from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv +from saicinpainting.training.modules.multidilated_conv import MultidilatedConv + + +class BaseDiscriminator(nn.Module): + @abc.abstractmethod + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Predict scores and get intermediate activations. Useful for feature matching loss + :return tuple (scores, list of intermediate activations) + """ + raise NotImplemented() + + +def get_conv_block_ctor(kind='default'): + if not isinstance(kind, str): + return kind + if kind == 'default': + return nn.Conv2d + if kind == 'depthwise': + return DepthWiseSeperableConv + if kind == 'multidilated': + return MultidilatedConv + raise ValueError(f'Unknown convolutional block kind {kind}') + + +def get_norm_layer(kind='bn'): + if not isinstance(kind, str): + return kind + if kind == 'bn': + return nn.BatchNorm2d + if kind == 'in': + return nn.InstanceNorm2d + raise ValueError(f'Unknown norm block kind {kind}') + + +def get_activation(kind='tanh'): + if kind == 'tanh': + return nn.Tanh() + if kind == 'sigmoid': + return nn.Sigmoid() + if kind is False: + return nn.Identity() + raise ValueError(f'Unknown activation kind {kind}') + + +class SimpleMultiStepGenerator(nn.Module): + def __init__(self, steps: List[nn.Module]): + super().__init__() + self.steps = nn.ModuleList(steps) + + def forward(self, x): + cur_in = x + outs = [] + for step in self.steps: + cur_out = step(cur_in) + outs.append(cur_out) + cur_in = torch.cat((cur_in, cur_out), dim=1) + return torch.cat(outs[::-1], dim=1) + +def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): + if kind == 'convtranspose': + return [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + norm_layer(min(max_features, int(ngf * mult / 2))), activation] + elif kind == 'bilinear': + return [nn.Upsample(scale_factor=2, mode='bilinear'), + DepthWiseSeperableConv(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=1, padding=1), + norm_layer(min(max_features, int(ngf * mult / 2))), activation] + else: + raise Exception(f"Invalid deconv kind: {kind}") \ No newline at end of file diff --git a/saicinpainting/training/modules/depthwise_sep_conv.py b/saicinpainting/training/modules/depthwise_sep_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..83dd15c3df1d9f40baf0091a373fa224532c9ddd --- /dev/null +++ b/saicinpainting/training/modules/depthwise_sep_conv.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +class DepthWiseSeperableConv(nn.Module): + def __init__(self, in_dim, out_dim, *args, **kwargs): + super().__init__() + if 'groups' in kwargs: + # ignoring groups for Depthwise Sep Conv + del kwargs['groups'] + + self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) + self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out \ No newline at end of file diff --git a/saicinpainting/training/modules/fake_fakes.py b/saicinpainting/training/modules/fake_fakes.py new file mode 100644 index 0000000000000000000000000000000000000000..45c4ad559cef2730b771a709197e00ae1c87683c --- /dev/null +++ b/saicinpainting/training/modules/fake_fakes.py @@ -0,0 +1,47 @@ +import torch +from kornia import SamplePadding +from kornia.augmentation import RandomAffine, CenterCrop + + +class FakeFakesGenerator: + def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): + self.grad_aug = RandomAffine(degrees=360, + translate=0.2, + padding_mode=SamplePadding.REFLECTION, + keepdim=False, + p=1) + self.img_aug = RandomAffine(degrees=img_aug_degree, + translate=img_aug_translate, + padding_mode=SamplePadding.REFLECTION, + keepdim=True, + p=1) + self.aug_proba = aug_proba + + def __call__(self, input_images, masks): + blend_masks = self._fill_masks_with_gradient(masks) + blend_target = self._make_blend_target(input_images) + result = input_images * (1 - blend_masks) + blend_target * blend_masks + return result, blend_masks + + def _make_blend_target(self, input_images): + batch_size = input_images.shape[0] + permuted = input_images[torch.randperm(batch_size)] + augmented = self.img_aug(input_images) + is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() + result = augmented * is_aug + permuted * (1 - is_aug) + return result + + def _fill_masks_with_gradient(self, masks): + batch_size, _, height, width = masks.shape + grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \ + .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2) + grad = self.grad_aug(grad) + grad = CenterCrop((height, width))(grad) + grad *= masks + + grad_for_min = grad + (1 - masks) * 10 + grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] + grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 + grad.clamp_(min=0, max=1) + + return grad diff --git a/saicinpainting/training/modules/ffc.py b/saicinpainting/training/modules/ffc.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7b84683fccb4bccac97b6371994fa6bb44dbe4 --- /dev/null +++ b/saicinpainting/training/modules/ffc.py @@ -0,0 +1,485 @@ +# Fast Fourier Convolution NeurIPS 2020 +# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py +# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from saicinpainting.training.modules.base import get_activation, BaseDiscriminator +from saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper +from saicinpainting.training.modules.squeeze_excitation import SELayer +from saicinpainting.utils import get_shape + + +class FFCSE_block(nn.Module): + + def __init__(self, channels, ratio_g): + super(FFCSE_block, self).__init__() + in_cg = int(channels * ratio_g) + in_cl = channels - in_cg + r = 16 + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv1 = nn.Conv2d(channels, channels // r, + kernel_size=1, bias=True) + self.relu1 = nn.ReLU(inplace=True) + self.conv_a2l = None if in_cl == 0 else nn.Conv2d( + channels // r, in_cl, kernel_size=1, bias=True) + self.conv_a2g = None if in_cg == 0 else nn.Conv2d( + channels // r, in_cg, kernel_size=1, bias=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = x if type(x) is tuple else (x, 0) + id_l, id_g = x + + x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1) + x = self.avgpool(x) + x = self.relu1(self.conv1(x)) + + x_l = 0 if self.conv_a2l is None else id_l * \ + self.sigmoid(self.conv_a2l(x)) + x_g = 0 if self.conv_a2g is None else id_g * \ + self.sigmoid(self.conv_a2g(x)) + return x_l, x_g + + +class FourierUnit(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', + spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + + return output + + +class SeparableFourierUnit(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, kernel_size=3): + # bn_layer not used + super(SeparableFourierUnit, self).__init__() + self.groups = groups + row_out_channels = out_channels // 2 + col_out_channels = out_channels - row_out_channels + self.row_conv = torch.nn.Conv2d(in_channels=in_channels * 2, + out_channels=row_out_channels * 2, + kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed + stride=1, padding=(kernel_size // 2, 0), + padding_mode='reflect', + groups=self.groups, bias=False) + self.col_conv = torch.nn.Conv2d(in_channels=in_channels * 2, + out_channels=col_out_channels * 2, + kernel_size=(kernel_size, 1), # kernel size is always like this, but the data will be transposed + stride=1, padding=(kernel_size // 2, 0), + padding_mode='reflect', + groups=self.groups, bias=False) + self.row_bn = torch.nn.BatchNorm2d(row_out_channels * 2) + self.col_bn = torch.nn.BatchNorm2d(col_out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + def process_branch(self, x, conv, bn): + batch = x.shape[0] + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + ffted = torch.fft.rfft(x, norm="ortho") + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + ffted = self.relu(bn(conv(ffted))) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + output = torch.fft.irfft(ffted, s=x.shape[-1:], norm="ortho") + return output + + + def forward(self, x): + rowwise = self.process_branch(x, self.row_conv, self.row_bn) + colwise = self.process_branch(x.permute(0, 1, 3, 2), self.col_conv, self.col_bn).permute(0, 1, 3, 2) + out = torch.cat((rowwise, colwise), dim=1) + return out + + +class SpectralTransform(nn.Module): + + def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, separable_fu=False, **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels // + 2, kernel_size=1, groups=groups, bias=False), + nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True) + ) + fu_class = SeparableFourierUnit if separable_fu else FourierUnit + self.fu = fu_class( + out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = fu_class( + out_channels // 2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) + + def forward(self, x): + + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat(torch.split( + x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() + xs = torch.cat(torch.split(xs, split_s, dim=-1), + dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + + return output + + +class FFC(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride=1, padding=0, + dilation=1, groups=1, bias=False, enable_lfu=True, + padding_type='reflect', gated=False, **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + #groups_g = 1 if groups == 1 else int(groups * ratio_gout) + #groups_l = 1 if groups == 1 else groups - groups_g + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module(in_cl, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module(in_cl, out_cg, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module(in_cg, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) + + return out_xl, out_xg + + +class FFC_BN_ACT(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, ratio_gin, ratio_gout, + stride=1, padding=0, dilation=1, groups=1, bias=False, + norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity, + padding_type='reflect', + enable_lfu=True, **kwargs): + super(FFC_BN_ACT, self).__init__() + self.ffc = FFC(in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride, padding, dilation, + groups, bias, enable_lfu, padding_type=padding_type, **kwargs) + lnorm = nn.Identity if ratio_gout == 1 else norm_layer + gnorm = nn.Identity if ratio_gout == 0 else norm_layer + global_channels = int(out_channels * ratio_gout) + self.bn_l = lnorm(out_channels - global_channels) + self.bn_g = gnorm(global_channels) + + lact = nn.Identity if ratio_gout == 1 else activation_layer + gact = nn.Identity if ratio_gout == 0 else activation_layer + self.act_l = lact(inplace=True) + self.act_g = gact(inplace=True) + + def forward(self, x): + x_l, x_g = self.ffc(x) + x_l = self.act_l(self.bn_l(x_l)) + x_g = self.act_g(self.bn_g(x_g)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1, + spatial_transform_kwargs=None, inline=False, **conv_kwargs): + super().__init__() + self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation, + norm_layer=norm_layer, + activation_layer=activation_layer, + padding_type=padding_type, + **conv_kwargs) + if spatial_transform_kwargs is not None: + self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs) + self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs) + self.inline = inline + + def forward(self, x): + if self.inline: + x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + + x_l, x_g = self.conv1((x_l, x_g)) + x_l, x_g = self.conv2((x_l, x_g)) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class ConcatTupleLayer(nn.Module): + def forward(self, x): + assert isinstance(x, tuple) + x_l, x_g = x + assert torch.is_tensor(x_l) or torch.is_tensor(x_g) + if not torch.is_tensor(x_g): + return x_l + return torch.cat(x, dim=1) + + +class FFCResNetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', activation_layer=nn.ReLU, + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), + init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={}, + spatial_transform_layers=None, spatial_transform_kwargs={}, + add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + model = [nn.ReflectionPad2d(3), + FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer, + activation_layer=activation_layer, **init_conv_kwargs)] + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + if i == n_downsampling - 1: + cur_conv_kwargs = dict(downsample_conv_kwargs) + cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0) + else: + cur_conv_kwargs = downsample_conv_kwargs + model += [FFC_BN_ACT(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + **cur_conv_kwargs)] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + ### resnet blocks + for i in range(n_blocks): + cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, **resnet_conv_kwargs) + if spatial_transform_layers is not None and i in spatial_transform_layers: + cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs) + model += [cur_resblock] + + model += [ConcatTupleLayer()] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation] + + if out_ffc: + model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer, + norm_layer=norm_layer, inline=True, **out_ffc_kwargs)] + + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class FFCNLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512, + init_conv_kwargs={}, conv_kwargs={}): + super().__init__() + self.n_layers = n_layers + + def _act_ctor(inplace=True): + return nn.LeakyReLU(negative_slope=0.2, inplace=inplace) + + kw = 3 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer, + activation_layer=_act_ctor, **init_conv_kwargs)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, max_features) + + cur_model = [ + FFC_BN_ACT(nf_prev, nf, + kernel_size=kw, stride=2, padding=padw, + norm_layer=norm_layer, + activation_layer=_act_ctor, + **conv_kwargs) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [ + FFC_BN_ACT(nf_prev, nf, + kernel_size=kw, stride=1, padding=padw, + norm_layer=norm_layer, + activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs), + **conv_kwargs), + ConcatTupleLayer() + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + feats = [] + for out in act[:-1]: + if isinstance(out, tuple): + if torch.is_tensor(out[1]): + out = torch.cat(out, dim=1) + else: + out = out[0] + feats.append(out) + return act[-1], feats diff --git a/saicinpainting/training/modules/multidilated_conv.py b/saicinpainting/training/modules/multidilated_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d267ee2aa5eb84b6a9291d0eaaff322c6c2802d0 --- /dev/null +++ b/saicinpainting/training/modules/multidilated_conv.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import random +from saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv + +class MultidilatedConv(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True, + shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs): + super().__init__() + convs = [] + self.equal_dim = equal_dim + assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode + if comb_mode in ('cat_out', 'cat_both'): + self.cat_out = True + if equal_dim: + assert out_dim % dilation_num == 0 + out_dims = [out_dim // dilation_num] * dilation_num + self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], []) + else: + out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] + out_dims.append(out_dim - sum(out_dims)) + index = [] + starts = [0] + out_dims[:-1] + lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] + for i in range(out_dims[-1]): + for j in range(dilation_num): + index += list(range(starts[j], starts[j] + lengths[j])) + starts[j] += lengths[j] + self.index = index + assert(len(index) == out_dim) + self.out_dims = out_dims + else: + self.cat_out = False + self.out_dims = [out_dim] * dilation_num + + if comb_mode in ('cat_in', 'cat_both'): + if equal_dim: + assert in_dim % dilation_num == 0 + in_dims = [in_dim // dilation_num] * dilation_num + else: + in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] + in_dims.append(in_dim - sum(in_dims)) + self.in_dims = in_dims + self.cat_in = True + else: + self.cat_in = False + self.in_dims = [in_dim] * dilation_num + + conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d + dilation = min_dilation + for i in range(dilation_num): + if isinstance(padding, int): + cur_padding = padding * dilation + else: + cur_padding = padding[i] + convs.append(conv_type( + self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs + )) + if i > 0 and shared_weights: + convs[-1].weight = convs[0].weight + convs[-1].bias = convs[0].bias + dilation *= 2 + self.convs = nn.ModuleList(convs) + + self.shuffle_in_channels = shuffle_in_channels + if self.shuffle_in_channels: + # shuffle list as shuffling of tensors is nondeterministic + in_channels_permute = list(range(in_dim)) + random.shuffle(in_channels_permute) + # save as buffer so it is saved and loaded with checkpoint + self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) + + def forward(self, x): + if self.shuffle_in_channels: + x = x[:, self.in_channels_permute] + + outs = [] + if self.cat_in: + if self.equal_dim: + x = x.chunk(len(self.convs), dim=1) + else: + new_x = [] + start = 0 + for dim in self.in_dims: + new_x.append(x[:, start:start+dim]) + start += dim + x = new_x + for i, conv in enumerate(self.convs): + if self.cat_in: + input = x[i] + else: + input = x + outs.append(conv(input)) + if self.cat_out: + out = torch.cat(outs, dim=1)[:, self.index] + else: + out = sum(outs) + return out diff --git a/saicinpainting/training/modules/multiscale.py b/saicinpainting/training/modules/multiscale.py new file mode 100644 index 0000000000000000000000000000000000000000..65f0a54925593e9da8106bfc6d65a4098ce001d7 --- /dev/null +++ b/saicinpainting/training/modules/multiscale.py @@ -0,0 +1,244 @@ +from typing import List, Tuple, Union, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from saicinpainting.training.modules.base import get_conv_block_ctor, get_activation +from saicinpainting.training.modules.pix2pixhd import ResnetBlock + + +class ResNetHead(nn.Module): + def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)): + assert (n_blocks >= 0) + super(ResNetHead, self).__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + model += [conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), + norm_layer(ngf * mult * 2), + activation] + + mult = 2 ** n_downsampling + + ### resnet blocks + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=conv_kind)] + + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class ResNetTail(nn.Module): + def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0, + add_in_proj=None): + assert (n_blocks >= 0) + super(ResNetTail, self).__init__() + + mult = 2 ** n_downsampling + + model = [] + + if add_in_proj is not None: + model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1)) + + ### resnet blocks + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=conv_kind)] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, + output_padding=1), + up_norm_layer(int(ngf * mult / 2)), + up_activation] + self.model = nn.Sequential(*model) + + out_layers = [] + for _ in range(out_extra_layers_n): + out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0), + up_norm_layer(ngf), + up_activation] + out_layers += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + + if add_out_act: + out_layers.append(get_activation('tanh' if add_out_act is True else add_out_act)) + + self.out_proj = nn.Sequential(*out_layers) + + def forward(self, input, return_last_act=False): + features = self.model(input) + out = self.out_proj(features) + if return_last_act: + return out, features + else: + return out + + +class MultiscaleResNet(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3, + norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0, + out_cumulative=False, return_only_hr=False): + super().__init__() + + self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling, + n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type, + conv_kind=conv_kind, activation=activation) + for i in range(n_scales)]) + tail_in_feats = ngf * (2 ** n_downsampling) + ngf + self.tails = nn.ModuleList([ResNetTail(output_nc, + ngf=ngf, n_downsampling=n_downsampling, + n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type, + conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer, + up_activation=up_activation, add_out_act=add_out_act, + out_extra_layers_n=out_extra_layers_n, + add_in_proj=None if (i == n_scales - 1) else tail_in_feats) + for i in range(n_scales)]) + + self.out_cumulative = out_cumulative + self.return_only_hr = return_only_hr + + @property + def num_scales(self): + return len(self.heads) + + def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \ + -> Union[torch.Tensor, List[torch.Tensor]]: + """ + :param ms_inputs: List of inputs of different resolutions from HR to LR + :param smallest_scales_num: int or None, number of smallest scales to take at input + :return: Depending on return_only_hr: + True: Only the most HR output + False: List of outputs of different resolutions from HR to LR + """ + if smallest_scales_num is None: + assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num) + smallest_scales_num = len(self.heads) + else: + assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num) + + cur_heads = self.heads[-smallest_scales_num:] + ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)] + + all_outputs = [] + prev_tail_features = None + for i in range(len(ms_features)): + scale_i = -i - 1 + + cur_tail_input = ms_features[-i - 1] + if prev_tail_features is not None: + if prev_tail_features.shape != cur_tail_input.shape: + prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:], + mode='bilinear', align_corners=False) + cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1) + + cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True) + + prev_tail_features = cur_tail_feats + all_outputs.append(cur_out) + + if self.out_cumulative: + all_outputs_cum = [all_outputs[0]] + for i in range(1, len(ms_features)): + cur_out = all_outputs[i] + cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:], + mode='bilinear', align_corners=False) + all_outputs_cum.append(cur_out_cum) + all_outputs = all_outputs_cum + + if self.return_only_hr: + return all_outputs[-1] + else: + return all_outputs[::-1] + + +class MultiscaleDiscriminatorSimple(nn.Module): + def __init__(self, ms_impl): + super().__init__() + self.ms_impl = nn.ModuleList(ms_impl) + + @property + def num_scales(self): + return len(self.ms_impl) + + def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \ + -> List[Tuple[torch.Tensor, List[torch.Tensor]]]: + """ + :param ms_inputs: List of inputs of different resolutions from HR to LR + :param smallest_scales_num: int or None, number of smallest scales to take at input + :return: List of pairs (prediction, features) for different resolutions from HR to LR + """ + if smallest_scales_num is None: + assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num) + smallest_scales_num = len(self.heads) + else: + assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \ + (len(self.ms_impl), len(ms_inputs), smallest_scales_num) + + return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)] + + +class SingleToMultiScaleInputMixin: + def forward(self, x: torch.Tensor) -> List: + orig_height, orig_width = x.shape[2:] + factors = [2 ** i for i in range(self.num_scales)] + ms_inputs = [F.interpolate(x, size=(orig_height // f, orig_width // f), mode='bilinear', align_corners=False) + for f in factors] + return super().forward(ms_inputs) + + +class GeneratorMultiToSingleOutputMixin: + def forward(self, x): + return super().forward(x)[0] + + +class DiscriminatorMultiToSingleOutputMixin: + def forward(self, x): + out_feat_tuples = super().forward(x) + return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist] + + +class DiscriminatorMultiToSingleOutputStackedMixin: + def __init__(self, *args, return_feats_only_levels=None, **kwargs): + super().__init__(*args, **kwargs) + self.return_feats_only_levels = return_feats_only_levels + + def forward(self, x): + out_feat_tuples = super().forward(x) + outs = [out for out, _ in out_feat_tuples] + scaled_outs = [outs[0]] + [F.interpolate(cur_out, size=outs[0].shape[-2:], + mode='bilinear', align_corners=False) + for cur_out in outs[1:]] + out = torch.cat(scaled_outs, dim=1) + if self.return_feats_only_levels is not None: + feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels] + else: + feat_lists = [flist for _, flist in out_feat_tuples] + feats = [f for flist in feat_lists for f in flist] + return out, feats + + +class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple): + pass + + +class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet): + pass diff --git a/saicinpainting/training/modules/pix2pixhd.py b/saicinpainting/training/modules/pix2pixhd.py new file mode 100644 index 0000000000000000000000000000000000000000..08c6afd777a88cd232592acbbf0ef25db8d43217 --- /dev/null +++ b/saicinpainting/training/modules/pix2pixhd.py @@ -0,0 +1,669 @@ +# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py +import collections +from functools import partial +import functools +import logging +from collections import defaultdict + +import numpy as np +import torch.nn as nn + +from saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation +from saicinpainting.training.modules.ffc import FFCResnetBlock +from saicinpainting.training.modules.multidilated_conv import MultidilatedConv + +class DotDict(defaultdict): + # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary + """dot.notation access to dictionary attributes""" + __getattr__ = defaultdict.get + __setattr__ = defaultdict.__setitem__ + __delattr__ = defaultdict.__delitem__ + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=None): + super(ResnetBlock, self).__init__() + self.in_dim = in_dim + self.dim = dim + if second_dilation is None: + second_dilation = dilation + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout, + conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups, + second_dilation=second_dilation) + + if self.in_dim is not None: + self.input_conv = nn.Conv2d(in_dim, dim, 1) + + self.out_channnels = dim + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=1): + conv_layer = get_conv_block_ctor(conv_kind) + + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(dilation)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(dilation)] + elif padding_type == 'zero': + p = dilation + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + if in_dim is None: + in_dim = dim + + conv_block += [conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(second_dilation)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(second_dilation)] + elif padding_type == 'zero': + p = second_dilation + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + x_before = x + if self.in_dim is not None: + x = self.input_conv(x) + out = x + self.conv_block(x_before) + return out + +class ResnetBlock5x5(nn.Module): + def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=None): + super(ResnetBlock5x5, self).__init__() + self.in_dim = in_dim + self.dim = dim + if second_dilation is None: + second_dilation = dilation + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout, + conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups, + second_dilation=second_dilation) + + if self.in_dim is not None: + self.input_conv = nn.Conv2d(in_dim, dim, 1) + + self.out_channnels = dim + + def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default', + dilation=1, in_dim=None, groups=1, second_dilation=1): + conv_layer = get_conv_block_ctor(conv_kind) + + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(dilation * 2)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(dilation * 2)] + elif padding_type == 'zero': + p = dilation * 2 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + if in_dim is None: + in_dim = dim + + conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(second_dilation * 2)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(second_dilation * 2)] + elif padding_type == 'zero': + p = second_dilation * 2 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + x_before = x + if self.in_dim is not None: + x = self.input_conv(x) + out = x + self.conv_block(x_before) + return out + + +class MultidilatedResnetBlock(nn.Module): + def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False): + super().__init__() + self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout) + + def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1): + conv_block = [] + conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type), + norm_layer(dim), + activation] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +class MultiDilatedGlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, + n_blocks=3, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', + deconv_kind='convtranspose', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True), + add_out_act=True, max_features=1024, multidilation_kwargs={}, + ffc_positions=None, ffc_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs) + norm_layer = get_norm_layer(norm_layer) + if affine is not None: + norm_layer = partial(norm_layer, affine=affine) + up_norm_layer = get_norm_layer(up_norm_layer) + if affine is not None: + up_norm_layer = partial(up_norm_layer, affine=affine) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + identity = Identity() + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + + model += [conv_layer(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1), + norm_layer(min(max_features, ngf * mult * 2)), + activation] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + ### resnet blocks + for i in range(n_blocks): + if ffc_positions is not None and i in ffc_positions: + model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU, + inline=True, **ffc_kwargs)] + model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type, + conv_layer=resnet_conv_layer, activation=activation, + norm_layer=norm_layer)] + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features) + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + +class ConfigGlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, + n_blocks=3, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', + deconv_kind='convtranspose', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True), + add_out_act=True, max_features=1024, + manual_block_spec=[], + resnet_block_kind='multidilatedresnetblock', + resnet_conv_kind='multidilated', + resnet_dilation=1, + multidilation_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs) + norm_layer = get_norm_layer(norm_layer) + if affine is not None: + norm_layer = partial(norm_layer, affine=affine) + up_norm_layer = get_norm_layer(up_norm_layer) + if affine is not None: + up_norm_layer = partial(up_norm_layer, affine=affine) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + identity = Identity() + + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + model += [conv_layer(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1), + norm_layer(min(max_features, ngf * mult * 2)), + activation] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + if len(manual_block_spec) == 0: + manual_block_spec = [ + DotDict(lambda : None, { + 'n_blocks': n_blocks, + 'use_default': True}) + ] + + ### resnet blocks + for block_spec in manual_block_spec: + def make_and_add_blocks(model, block_spec): + block_spec = DotDict(lambda : None, block_spec) + if not block_spec.use_default: + resnet_conv_layer = functools.partial(get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs) + resnet_conv_kind = block_spec.resnet_conv_kind + resnet_block_kind = block_spec.resnet_block_kind + if block_spec.resnet_dilation is not None: + resnet_dilation = block_spec.resnet_dilation + for i in range(block_spec.n_blocks): + if resnet_block_kind == "multidilatedresnetblock": + model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type, + conv_layer=resnet_conv_layer, activation=activation, + norm_layer=norm_layer)] + if resnet_block_kind == "resnetblock": + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=resnet_conv_kind)] + if resnet_block_kind == "resnetblock5x5": + model += [ResnetBlock5x5(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=resnet_conv_kind)] + if resnet_block_kind == "resnetblockdwdil": + model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer, + conv_kind=resnet_conv_kind, dilation=resnet_dilation, second_dilation=resnet_dilation)] + make_and_add_blocks(model, block_spec) + + ### upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features) + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs): + blocks = [] + for i in range(dilated_blocks_n): + if dilation_block_kind == 'simple': + blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1))) + elif dilation_block_kind == 'multi': + blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs)) + else: + raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"') + return blocks + + +class GlobalGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, + padding_type='reflect', conv_kind='default', activation=nn.ReLU(True), + up_norm_layer=nn.BatchNorm2d, affine=None, + up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0, + dilated_blocks_n_middle=0, + add_out_act=True, + max_features=1024, is_resblock_depthwise=False, + ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None, + dilation_block_kind='simple', multidilation_kwargs={}): + assert (n_blocks >= 0) + super().__init__() + + conv_layer = get_conv_block_ctor(conv_kind) + norm_layer = get_norm_layer(norm_layer) + if affine is not None: + norm_layer = partial(norm_layer, affine=affine) + up_norm_layer = get_norm_layer(up_norm_layer) + if affine is not None: + up_norm_layer = partial(up_norm_layer, affine=affine) + + if ffc_positions is not None: + ffc_positions = collections.Counter(ffc_positions) + + model = [nn.ReflectionPad2d(3), + conv_layer(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + activation] + + identity = Identity() + ### downsample + for i in range(n_downsampling): + mult = 2 ** i + + model += [conv_layer(min(max_features, ngf * mult), + min(max_features, ngf * mult * 2), + kernel_size=3, stride=2, padding=1), + norm_layer(min(max_features, ngf * mult * 2)), + activation] + + mult = 2 ** n_downsampling + feats_num_bottleneck = min(max_features, ngf * mult) + + dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type, + activation=activation, norm_layer=norm_layer) + if dilation_block_kind == 'simple': + dilated_block_kwargs['conv_kind'] = conv_kind + elif dilation_block_kind == 'multi': + dilated_block_kwargs['conv_layer'] = functools.partial( + get_conv_block_ctor('multidilated'), **multidilation_kwargs) + + # dilated blocks at the start of the bottleneck sausage + if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0: + model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs) + + # resnet blocks + for i in range(n_blocks): + # dilated blocks at the middle of the bottleneck sausage + if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0: + model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs) + + if ffc_positions is not None and i in ffc_positions: + for _ in range(ffc_positions[i]): # same position can occur more than once + model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU, + inline=True, **ffc_kwargs)] + + if is_resblock_depthwise: + resblock_groups = feats_num_bottleneck + else: + resblock_groups = 1 + + model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation, + norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups, + dilation=dilation, second_dilation=second_dilation)] + + + # dilated blocks at the end of the bottleneck sausage + if dilated_blocks_n is not None and dilated_blocks_n > 0: + model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs) + + # upsample + for i in range(n_downsampling): + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(min(max_features, ngf * mult), + min(max_features, int(ngf * mult / 2)), + kernel_size=3, stride=2, padding=1, output_padding=1), + up_norm_layer(min(max_features, int(ngf * mult / 2))), + up_activation] + model += [nn.ReflectionPad2d(3), + nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def forward(self, input): + return self.model(input) + + +class GlobalGeneratorGated(GlobalGenerator): + def __init__(self, *args, **kwargs): + real_kwargs=dict( + conv_kind='gated_bn_relu', + activation=nn.Identity(), + norm_layer=nn.Identity + ) + real_kwargs.update(kwargs) + super().__init__(*args, **real_kwargs) + + +class GlobalGeneratorFromSuperChannels(nn.Module): + def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True): + super().__init__() + self.n_downsampling = n_downsampling + norm_layer = get_norm_layer(norm_layer) + if type(norm_layer) == functools.partial: + use_bias = (norm_layer.func == nn.InstanceNorm2d) + else: + use_bias = (norm_layer == nn.InstanceNorm2d) + + channels = self.convert_super_channels(super_channels) + self.channels = channels + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias), + norm_layer(channels[0]), + nn.ReLU(True)] + + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(channels[1+i]), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + + n_blocks1 = n_blocks // 3 + n_blocks2 = n_blocks1 + n_blocks3 = n_blocks - n_blocks1 - n_blocks2 + + for i in range(n_blocks1): + c = n_downsampling + dim = channels[c] + model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)] + + for i in range(n_blocks2): + c = n_downsampling+1 + dim = channels[c] + kwargs = {} + if i == 0: + kwargs = {"in_dim": channels[c-1]} + model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)] + + for i in range(n_blocks3): + c = n_downsampling+2 + dim = channels[c] + kwargs = {} + if i == 0: + kwargs = {"in_dim": channels[c-1]} + model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(channels[n_downsampling+3+i], + channels[n_downsampling+3+i+1], + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(channels[n_downsampling+3+i+1]), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)] + + if add_out_act: + model.append(get_activation('tanh' if add_out_act is True else add_out_act)) + self.model = nn.Sequential(*model) + + def convert_super_channels(self, super_channels): + n_downsampling = self.n_downsampling + result = [] + cnt = 0 + + if n_downsampling == 2: + N1 = 10 + elif n_downsampling == 3: + N1 = 13 + else: + raise NotImplementedError + + for i in range(0, N1): + if i in [1,4,7,10]: + channel = super_channels[cnt] * (2 ** cnt) + config = {'channel': channel} + result.append(channel) + logging.info(f"Downsample channels {result[-1]}") + cnt += 1 + + for i in range(3): + for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)): + if len(super_channels) == 6: + channel = super_channels[3] * 4 + else: + channel = super_channels[i + 3] * 4 + config = {'channel': channel} + if counter == 0: + result.append(channel) + logging.info(f"Bottleneck channels {result[-1]}") + cnt = 2 + + for i in range(N1+9, N1+21): + if i in [22, 25,28]: + cnt -= 1 + if len(super_channels) == 6: + channel = super_channels[5 - cnt] * (2 ** cnt) + else: + channel = super_channels[7 - cnt] * (2 ** cnt) + result.append(int(channel)) + logging.info(f"Upsample channels {result[-1]}") + return result + + def forward(self, input): + return self.model(input) + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] + + +class MultidilatedNLayerDiscriminator(BaseDiscriminator): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}): + super().__init__() + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + nf_prev = nf + nf = min(nf * 2, 512) + + cur_model = [] + cur_model += [ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ] + sequence.append(cur_model) + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + + def get_all_activations(self, x): + res = [x] + for n in range(self.n_layers + 2): + model = getattr(self, 'model' + str(n)) + res.append(model(res[-1])) + return res[1:] + + def forward(self, x): + act = self.get_all_activations(x) + return act[-1], act[:-1] + + +class NLayerDiscriminatorAsGen(NLayerDiscriminator): + def forward(self, x): + return super().forward(x)[0] diff --git a/saicinpainting/training/modules/spatial_transform.py b/saicinpainting/training/modules/spatial_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..2de024ba08c549605a08b64d096f1f0db7b7722a --- /dev/null +++ b/saicinpainting/training/modules/spatial_transform.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from kornia.geometry.transform import rotate + + +class LearnableSpatialTransformWrapper(nn.Module): + def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True): + super().__init__() + self.impl = impl + self.angle = torch.rand(1) * angle_init_range + if train_angle: + self.angle = nn.Parameter(self.angle, requires_grad=True) + self.pad_coef = pad_coef + + def forward(self, x): + if torch.is_tensor(x): + return self.inverse_transform(self.impl(self.transform(x)), x) + elif isinstance(x, tuple): + x_trans = tuple(self.transform(elem) for elem in x) + y_trans = self.impl(x_trans) + return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) + else: + raise ValueError(f'Unexpected input type {type(x)}') + + def transform(self, x): + height, width = x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') + x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) + return x_padded_rotated + + def inverse_transform(self, y_padded_rotated, orig_x): + height, width = orig_x.shape[2:] + pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) + + y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) + y_height, y_width = y_padded.shape[2:] + y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] + return y + + +if __name__ == '__main__': + layer = LearnableSpatialTransformWrapper(nn.Identity()) + x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float() + y = layer(x) + assert x.shape == y.shape + assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) + print('all ok') diff --git a/saicinpainting/training/modules/squeeze_excitation.py b/saicinpainting/training/modules/squeeze_excitation.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d902bb30c071acbc0fa919a134c80fed86bd6c --- /dev/null +++ b/saicinpainting/training/modules/squeeze_excitation.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res diff --git a/saicinpainting/training/trainers/__init__.py b/saicinpainting/training/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c59241f553efe4e2dd6b198e2e5656a2b1488857 --- /dev/null +++ b/saicinpainting/training/trainers/__init__.py @@ -0,0 +1,30 @@ +import logging +import torch +from saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule + + +def get_training_model_class(kind): + if kind == 'default': + return DefaultInpaintingTrainingModule + + raise ValueError(f'Unknown trainer module {kind}') + + +def make_training_model(config): + kind = config.training_model.kind + kwargs = dict(config.training_model) + kwargs.pop('kind') + kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' + + logging.info(f'Make training model {kind}') + + cls = get_training_model_class(kind) + return cls(config, **kwargs) + + +def load_checkpoint(train_config, path, map_location='cuda', strict=True): + model: torch.nn.Module = make_training_model(train_config) + state = torch.load(path, map_location=map_location) + model.load_state_dict(state['state_dict'], strict=strict) + model.on_load_checkpoint(state) + return model diff --git a/saicinpainting/training/trainers/base.py b/saicinpainting/training/trainers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f1b1c66fc96e7edfba7b1ee193272f92b5db7438 --- /dev/null +++ b/saicinpainting/training/trainers/base.py @@ -0,0 +1,291 @@ +import copy +import logging +from typing import Dict, Tuple + +import pandas as pd +import pytorch_lightning as ptl +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DistributedSampler + +from saicinpainting.evaluation import make_evaluator +from saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader +from saicinpainting.training.losses.adversarial import make_discrim_loss +from saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL +from saicinpainting.training.modules import make_generator, make_discriminator +from saicinpainting.training.visualizers import make_visualizer +from saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \ + get_has_ddp_rank + +LOGGER = logging.getLogger(__name__) + + +def make_optimizer(parameters, kind='adamw', **kwargs): + if kind == 'adam': + optimizer_class = torch.optim.Adam + elif kind == 'adamw': + optimizer_class = torch.optim.AdamW + else: + raise ValueError(f'Unknown optimizer kind {kind}') + return optimizer_class(parameters, **kwargs) + + +def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999): + with torch.no_grad(): + res_params = dict(result.named_parameters()) + new_params = dict(new_iterate_model.named_parameters()) + + for k in res_params.keys(): + res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay) + + +def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'): + batch_size, _, height, width = base_tensor.shape + cur_height, cur_width = height, width + result = [] + align_corners = False if scale_mode in ('bilinear', 'bicubic') else None + for _ in range(scales): + cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device) + cur_sample_scaled = F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners) + result.append(cur_sample_scaled) + cur_height //= 2 + cur_width //= 2 + return torch.cat(result, dim=1) + + +class BaseInpaintingTrainingModule(ptl.LightningModule): + def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100, + average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000, + average_generator_period=10, store_discr_outputs_for_vis=False, + **kwargs): + super().__init__(*args, **kwargs) + LOGGER.info('BaseInpaintingTrainingModule init called') + + self.config = config + + self.generator = make_generator(config, **self.config.generator) + self.use_ddp = use_ddp + + if not get_has_ddp_rank(): + LOGGER.info(f'Generator\n{self.generator}') + + if not predict_only: + self.save_hyperparameters(self.config) + self.discriminator = make_discriminator(**self.config.discriminator) + self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial) + self.visualizer = make_visualizer(**self.config.visualizer) + self.val_evaluator = make_evaluator(**self.config.evaluator) + self.test_evaluator = make_evaluator(**self.config.evaluator) + + if not get_has_ddp_rank(): + LOGGER.info(f'Discriminator\n{self.discriminator}') + + extra_val = self.config.data.get('extra_val', ()) + if extra_val: + self.extra_val_titles = list(extra_val) + self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator) + for k in extra_val}) + else: + self.extra_evaluators = {} + + self.average_generator = average_generator + self.generator_avg_beta = generator_avg_beta + self.average_generator_start_step = average_generator_start_step + self.average_generator_period = average_generator_period + self.generator_average = None + self.last_generator_averaging_step = -1 + self.store_discr_outputs_for_vis = store_discr_outputs_for_vis + + if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0: + self.loss_l1 = nn.L1Loss(reduction='none') + + if self.config.losses.get("mse", {"weight": 0})['weight'] > 0: + self.loss_mse = nn.MSELoss(reduction='none') + + if self.config.losses.perceptual.weight > 0: + self.loss_pl = PerceptualLoss() + + if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0: + self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl) + else: + self.loss_resnet_pl = None + + self.visualize_each_iters = visualize_each_iters + LOGGER.info('BaseInpaintingTrainingModule init done') + + def configure_optimizers(self): + discriminator_params = list(self.discriminator.parameters()) + return [ + dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)), + dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)), + ] + + def train_dataloader(self): + kwargs = dict(self.config.data.train) + if self.use_ddp: + kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes, + rank=self.trainer.global_rank, + shuffle=True) + dataloader = make_default_train_dataloader(**self.config.data.train) + return dataloader + + def val_dataloader(self): + res = [make_default_val_dataloader(**self.config.data.val)] + + if self.config.data.visual_test is not None: + res = res + [make_default_val_dataloader(**self.config.data.visual_test)] + else: + res = res + res + + extra_val = self.config.data.get('extra_val', ()) + if extra_val: + res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles] + + return res + + def training_step(self, batch, batch_idx, optimizer_idx=None): + self._is_training_step = True + return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx) + + def validation_step(self, batch, batch_idx, dataloader_idx): + extra_val_key = None + if dataloader_idx == 0: + mode = 'val' + elif dataloader_idx == 1: + mode = 'test' + else: + mode = 'extra_val' + extra_val_key = self.extra_val_titles[dataloader_idx - 2] + self._is_training_step = False + return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key) + + def training_step_end(self, batch_parts_outputs): + if self.training and self.average_generator \ + and self.global_step >= self.average_generator_start_step \ + and self.global_step >= self.last_generator_averaging_step + self.average_generator_period: + if self.generator_average is None: + self.generator_average = copy.deepcopy(self.generator) + else: + update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta) + self.last_generator_averaging_step = self.global_step + + full_loss = (batch_parts_outputs['loss'].mean() + if torch.is_tensor(batch_parts_outputs['loss']) # loss is not tensor when no discriminator used + else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True)) + log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()} + self.log_dict(log_info, on_step=True, on_epoch=False) + return full_loss + + def validation_epoch_end(self, outputs): + outputs = [step_out for out_group in outputs for step_out in out_group] + averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs) + self.log_dict({k: v.mean() for k, v in averaged_logs.items()}) + + pd.set_option('display.max_columns', 500) + pd.set_option('display.width', 1000) + + # standard validation + val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s] + val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states) + val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0) + val_evaluator_res_df.dropna(axis=1, how='all', inplace=True) + LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, ' + f'total {self.global_step} iterations:\n{val_evaluator_res_df}') + + for k, v in flatten_dict(val_evaluator_res).items(): + self.log(f'val_{k}', v) + + # standard visual test + test_evaluator_states = [s['test_evaluator_state'] for s in outputs + if 'test_evaluator_state' in s] + test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states) + test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0) + test_evaluator_res_df.dropna(axis=1, how='all', inplace=True) + LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, ' + f'total {self.global_step} iterations:\n{test_evaluator_res_df}') + + for k, v in flatten_dict(test_evaluator_res).items(): + self.log(f'test_{k}', v) + + # extra validations + if self.extra_evaluators: + for cur_eval_title, cur_evaluator in self.extra_evaluators.items(): + cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state' + cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s] + cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states) + cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0) + cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True) + LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, ' + f'total {self.global_step} iterations:\n{cur_evaluator_res_df}') + for k, v in flatten_dict(cur_evaluator_res).items(): + self.log(f'extra_val_{cur_eval_title}_{k}', v) + + def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None): + if optimizer_idx == 0: # step for generator + set_requires_grad(self.generator, True) + set_requires_grad(self.discriminator, False) + elif optimizer_idx == 1: # step for discriminator + set_requires_grad(self.generator, False) + set_requires_grad(self.discriminator, True) + + batch = self(batch) + + total_loss = 0 + metrics = {} + + if optimizer_idx is None or optimizer_idx == 0: # step for generator + total_loss, metrics = self.generator_loss(batch) + + elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator + if self.config.losses.adversarial.weight > 0: + total_loss, metrics = self.discriminator_loss(batch) + + if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'): + if self.config.losses.adversarial.weight > 0: + if self.store_discr_outputs_for_vis: + with torch.no_grad(): + self.store_discr_outputs(batch) + vis_suffix = f'_{mode}' + if mode == 'extra_val': + vis_suffix += f'_{extra_val_key}' + self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix) + + metrics_prefix = f'{mode}_' + if mode == 'extra_val': + metrics_prefix += f'{extra_val_key}_' + result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix)) + if mode == 'val': + result['val_evaluator_state'] = self.val_evaluator.process_batch(batch) + elif mode == 'test': + result['test_evaluator_state'] = self.test_evaluator.process_batch(batch) + elif mode == 'extra_val': + result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch) + + return result + + def get_current_generator(self, no_average=False): + if not no_average and not self.training and self.average_generator and self.generator_average is not None: + return self.generator_average + return self.generator + + def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" + raise NotImplementedError() + + def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError() + + def store_discr_outputs(self, batch): + out_size = batch['image'].shape[2:] + discr_real_out, _ = self.discriminator(batch['image']) + discr_fake_out, _ = self.discriminator(batch['predicted_image']) + batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest') + batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest') + batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake'] + + def get_ddp_rank(self): + return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None diff --git a/saicinpainting/training/trainers/default.py b/saicinpainting/training/trainers/default.py new file mode 100644 index 0000000000000000000000000000000000000000..86c7f0fab42924bfc93a031e851117634c70f593 --- /dev/null +++ b/saicinpainting/training/trainers/default.py @@ -0,0 +1,175 @@ +import logging + +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf + +from saicinpainting.training.data.datasets import make_constant_area_crop_params +from saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter +from saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss +from saicinpainting.training.modules.fake_fakes import FakeFakesGenerator +from saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise +from saicinpainting.utils import add_prefix_to_keys, get_ramp + +LOGGER = logging.getLogger(__name__) + + +def make_constant_area_crop_batch(batch, **kwargs): + crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2], + img_width=batch['image'].shape[3], + **kwargs) + batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width] + batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width] + return batch + + +class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): + def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image', + add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None, + distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False, + fake_fakes_proba=0, fake_fakes_generator_kwargs=None, + **kwargs): + super().__init__(*args, **kwargs) + self.concat_mask = concat_mask + self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None + self.image_to_discriminator = image_to_discriminator + self.add_noise_kwargs = add_noise_kwargs + self.noise_fill_hole = noise_fill_hole + self.const_area_crop_kwargs = const_area_crop_kwargs + self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \ + if distance_weighter_kwargs is not None else None + self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr + + self.fake_fakes_proba = fake_fakes_proba + if self.fake_fakes_proba > 1e-3: + self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {})) + + def forward(self, batch): + if self.training and self.rescale_size_getter is not None: + cur_size = self.rescale_size_getter(self.global_step) + batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False) + batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest') + + if self.training and self.const_area_crop_kwargs is not None: + batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs) + + img = batch['image'] + mask = batch['mask'] + + masked_img = img * (1 - mask) + + if self.add_noise_kwargs is not None: + noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs) + if self.noise_fill_hole: + masked_img = masked_img + mask * noise[:, :masked_img.shape[1]] + masked_img = torch.cat([masked_img, noise], dim=1) + + if self.concat_mask: + masked_img = torch.cat([masked_img, mask], dim=1) + + batch['predicted_image'] = self.generator(masked_img) + batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image'] + + if self.fake_fakes_proba > 1e-3: + if self.training and torch.rand(1).item() < self.fake_fakes_proba: + batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask) + batch['use_fake_fakes'] = True + else: + batch['fake_fakes'] = torch.zeros_like(img) + batch['fake_fakes_masks'] = torch.zeros_like(mask) + batch['use_fake_fakes'] = False + + batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \ + if self.refine_mask_for_losses is not None and self.training \ + else mask + + return batch + + def generator_loss(self, batch): + img = batch['image'] + predicted_img = batch[self.image_to_discriminator] + original_mask = batch['mask'] + supervised_mask = batch['mask_for_losses'] + + # L1 + l1_value = masked_l1_loss(predicted_img, img, supervised_mask, + self.config.losses.l1.weight_known, + self.config.losses.l1.weight_missing) + + total_loss = l1_value + metrics = dict(gen_l1=l1_value) + + # vgg-based perceptual loss + if self.config.losses.perceptual.weight > 0: + pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight + total_loss = total_loss + pl_value + metrics['gen_pl'] = pl_value + + # discriminator + # adversarial_loss calls backward by itself + mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask + self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img, + generator=self.generator, discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(img) + discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) + adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img, + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=mask_for_discr) + total_loss = total_loss + adv_gen_loss + metrics['gen_adv'] = adv_gen_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + # feature matching + if self.config.losses.feature_matching.weight > 0: + need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False) + mask_for_fm = supervised_mask if need_mask_in_fm else None + fm_value = feature_matching_loss(discr_fake_features, discr_real_features, + mask=mask_for_fm) * self.config.losses.feature_matching.weight + total_loss = total_loss + fm_value + metrics['gen_fm'] = fm_value + + if self.loss_resnet_pl is not None: + resnet_pl_value = self.loss_resnet_pl(predicted_img, img) + total_loss = total_loss + resnet_pl_value + metrics['gen_resnet_pl'] = resnet_pl_value + + return total_loss, metrics + + def discriminator_loss(self, batch): + total_loss = 0 + metrics = {} + + predicted_img = batch[self.image_to_discriminator].detach() + self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img, + generator=self.generator, discriminator=self.discriminator) + discr_real_pred, discr_real_features = self.discriminator(batch['image']) + discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) + adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'], + fake_batch=predicted_img, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_pred, + mask=batch['mask']) + total_loss = total_loss + adv_discr_loss + metrics['discr_adv'] = adv_discr_loss + metrics.update(add_prefix_to_keys(adv_metrics, 'adv_')) + + + if batch.get('use_fake_fakes', False): + fake_fakes = batch['fake_fakes'] + self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes, + generator=self.generator, discriminator=self.discriminator) + discr_fake_fakes_pred, _ = self.discriminator(fake_fakes) + fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss( + real_batch=batch['image'], + fake_batch=fake_fakes, + discr_real_pred=discr_real_pred, + discr_fake_pred=discr_fake_fakes_pred, + mask=batch['mask'] + ) + total_loss = total_loss + fake_fakes_adv_discr_loss + metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss + metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_')) + + return total_loss, metrics diff --git a/saicinpainting/training/visualizers/__init__.py b/saicinpainting/training/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4770d1f15a6790ab9606c7b9881f798c8e2d9545 --- /dev/null +++ b/saicinpainting/training/visualizers/__init__.py @@ -0,0 +1,15 @@ +import logging + +from saicinpainting.training.visualizers.directory import DirectoryVisualizer +from saicinpainting.training.visualizers.noop import NoopVisualizer + + +def make_visualizer(kind, **kwargs): + logging.info(f'Make visualizer {kind}') + + if kind == 'directory': + return DirectoryVisualizer(**kwargs) + if kind == 'noop': + return NoopVisualizer() + + raise ValueError(f'Unknown visualizer kind {kind}') diff --git a/saicinpainting/training/visualizers/base.py b/saicinpainting/training/visualizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..675f01682ddf5e31b6cc341735378c6f3b242e49 --- /dev/null +++ b/saicinpainting/training/visualizers/base.py @@ -0,0 +1,73 @@ +import abc +from typing import Dict, List + +import numpy as np +import torch +from skimage import color +from skimage.segmentation import mark_boundaries + +from . import colors + +COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation + + +class BaseVisualizer: + @abc.abstractmethod + def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): + """ + Take a batch, make an image from it and visualize + """ + raise NotImplementedError() + + +def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str], + last_without_mask=True, rescale_keys=None, mask_only_first=None, + black_mask=False) -> np.ndarray: + mask = images_dict['mask'] > 0.5 + result = [] + for i, k in enumerate(keys): + img = images_dict[k] + img = np.transpose(img, (1, 2, 0)) + + if rescale_keys is not None and k in rescale_keys: + img = img - img.min() + img /= img.max() + 1e-5 + if len(img.shape) == 2: + img = np.expand_dims(img, 2) + + if img.shape[2] == 1: + img = np.repeat(img, 3, axis=2) + elif (img.shape[2] > 3): + img_classes = img.argmax(2) + img = color.label2rgb(img_classes, colors=COLORS) + + if mask_only_first: + need_mark_boundaries = i == 0 + else: + need_mark_boundaries = i < len(keys) - 1 or not last_without_mask + + if need_mark_boundaries: + if black_mask: + img = img * (1 - mask[0][..., None]) + img = mark_boundaries(img, + mask[0], + color=(1., 0., 0.), + outline_color=(1., 1., 1.), + mode='thick') + result.append(img) + return np.concatenate(result, axis=1) + + +def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, + last_without_mask=True, rescale_keys=None) -> np.ndarray: + batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() + if k in keys or k == 'mask'} + + batch_size = next(iter(batch.values())).shape[0] + items_to_vis = min(batch_size, max_items) + result = [] + for i in range(items_to_vis): + cur_dct = {k: tens[i] for k, tens in batch.items()} + result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, + rescale_keys=rescale_keys)) + return np.concatenate(result, axis=0) diff --git a/saicinpainting/training/visualizers/colors.py b/saicinpainting/training/visualizers/colors.py new file mode 100644 index 0000000000000000000000000000000000000000..9e9e39182c58cb06a1c5e97a7e6c497cc3388ebe --- /dev/null +++ b/saicinpainting/training/visualizers/colors.py @@ -0,0 +1,76 @@ +import random +import colorsys + +import numpy as np +import matplotlib +matplotlib.use('agg') +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + + +def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False): + # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib + """ + Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks + :param nlabels: Number of labels (size of colormap) + :param type: 'bright' for strong colors, 'soft' for pastel colors + :param first_color_black: Option to use first color as black, True or False + :param last_color_black: Option to use last color as black, True or False + :param verbose: Prints the number of labels and shows the colormap. True or False + :return: colormap for matplotlib + """ + if type not in ('bright', 'soft'): + print ('Please choose "bright" or "soft" for type') + return + + if verbose: + print('Number of labels: ' + str(nlabels)) + + # Generate color map for bright colors, based on hsv + if type == 'bright': + randHSVcolors = [(np.random.uniform(low=0.0, high=1), + np.random.uniform(low=0.2, high=1), + np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] + + # Convert HSV list to RGB + randRGBcolors = [] + for HSVcolor in randHSVcolors: + randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) + + if first_color_black: + randRGBcolors[0] = [0, 0, 0] + + if last_color_black: + randRGBcolors[-1] = [0, 0, 0] + + random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) + + # Generate soft pastel colors, by limiting the RGB spectrum + if type == 'soft': + low = 0.6 + high = 0.95 + randRGBcolors = [(np.random.uniform(low=low, high=high), + np.random.uniform(low=low, high=high), + np.random.uniform(low=low, high=high)) for i in range(nlabels)] + + if first_color_black: + randRGBcolors[0] = [0, 0, 0] + + if last_color_black: + randRGBcolors[-1] = [0, 0, 0] + random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) + + # Display colorbar + if verbose: + from matplotlib import colors, colorbar + from matplotlib import pyplot as plt + fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) + + bounds = np.linspace(0, nlabels, nlabels + 1) + norm = colors.BoundaryNorm(bounds, nlabels) + + cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, + boundaries=bounds, format='%1i', orientation=u'horizontal') + + return randRGBcolors, random_colormap + diff --git a/saicinpainting/training/visualizers/directory.py b/saicinpainting/training/visualizers/directory.py new file mode 100644 index 0000000000000000000000000000000000000000..bc42e00500c7a5b70b2cef83b03e45b5bb471ff8 --- /dev/null +++ b/saicinpainting/training/visualizers/directory.py @@ -0,0 +1,36 @@ +import os + +import cv2 +import numpy as np + +from saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch +from saicinpainting.utils import check_and_warn_input_range + + +class DirectoryVisualizer(BaseVisualizer): + DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ') + + def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, + last_without_mask=True, rescale_keys=None): + self.outdir = outdir + os.makedirs(self.outdir, exist_ok=True) + self.key_order = key_order + self.max_items_in_batch = max_items_in_batch + self.last_without_mask = last_without_mask + self.rescale_keys = rescale_keys + + def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): + check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image') + vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch, + last_without_mask=self.last_without_mask, + rescale_keys=self.rescale_keys) + + vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') + + curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}') + os.makedirs(curoutdir, exist_ok=True) + rank_suffix = f'_r{rank}' if rank is not None else '' + out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg') + + vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) + cv2.imwrite(out_fname, vis_img) diff --git a/saicinpainting/training/visualizers/noop.py b/saicinpainting/training/visualizers/noop.py new file mode 100644 index 0000000000000000000000000000000000000000..4175089a54a8484d51e6c879c1a99c4e4d961d15 --- /dev/null +++ b/saicinpainting/training/visualizers/noop.py @@ -0,0 +1,9 @@ +from saicinpainting.training.visualizers.base import BaseVisualizer + + +class NoopVisualizer(BaseVisualizer): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): + pass diff --git a/saicinpainting/utils.py b/saicinpainting/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d0914320eab96e197ae379b94ea7eeb2fe5dfd79 --- /dev/null +++ b/saicinpainting/utils.py @@ -0,0 +1,174 @@ +import bisect +import functools +import logging +import numbers +import os +import signal +import sys +import traceback +import warnings + +import torch +from pytorch_lightning import seed_everything + +LOGGER = logging.getLogger(__name__) + + +def check_and_warn_input_range(tensor, min_value, max_value, name): + actual_min = tensor.min() + actual_max = tensor.max() + if actual_min < min_value or actual_max > max_value: + warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}") + + +def sum_dict_with_prefix(target, cur_dict, prefix, default=0): + for k, v in cur_dict.items(): + target_key = prefix + k + target[target_key] = target.get(target_key, default) + v + + +def average_dicts(dict_list): + result = {} + norm = 1e-3 + for dct in dict_list: + sum_dict_with_prefix(result, dct, '') + norm += 1 + for k in list(result): + result[k] /= norm + return result + + +def add_prefix_to_keys(dct, prefix): + return {prefix + k: v for k, v in dct.items()} + + +def set_requires_grad(module, value): + for param in module.parameters(): + param.requires_grad = value + + +def flatten_dict(dct): + result = {} + for k, v in dct.items(): + if isinstance(k, tuple): + k = '_'.join(k) + if isinstance(v, dict): + for sub_k, sub_v in flatten_dict(v).items(): + result[f'{k}_{sub_k}'] = sub_v + else: + result[k] = v + return result + + +class LinearRamp: + def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): + self.start_value = start_value + self.end_value = end_value + self.start_iter = start_iter + self.end_iter = end_iter + + def __call__(self, i): + if i < self.start_iter: + return self.start_value + if i >= self.end_iter: + return self.end_value + part = (i - self.start_iter) / (self.end_iter - self.start_iter) + return self.start_value * (1 - part) + self.end_value * part + + +class LadderRamp: + def __init__(self, start_iters, values): + self.start_iters = start_iters + self.values = values + assert len(values) == len(start_iters) + 1, (len(values), len(start_iters)) + + def __call__(self, i): + segment_i = bisect.bisect_right(self.start_iters, i) + return self.values[segment_i] + + +def get_ramp(kind='ladder', **kwargs): + if kind == 'linear': + return LinearRamp(**kwargs) + if kind == 'ladder': + return LadderRamp(**kwargs) + raise ValueError(f'Unexpected ramp kind: {kind}') + + +def print_traceback_handler(sig, frame): + LOGGER.warning(f'Received signal {sig}') + bt = ''.join(traceback.format_stack()) + LOGGER.warning(f'Requested stack trace:\n{bt}') + + +def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler): + LOGGER.warning(f'Setting signal {sig} handler {handler}') + signal.signal(sig, handler) + + +def handle_deterministic_config(config): + seed = dict(config).get('seed', None) + if seed is None: + return False + + seed_everything(seed) + return True + + +def get_shape(t): + if torch.is_tensor(t): + return tuple(t.shape) + elif isinstance(t, dict): + return {n: get_shape(q) for n, q in t.items()} + elif isinstance(t, (list, tuple)): + return [get_shape(q) for q in t] + elif isinstance(t, numbers.Number): + return type(t) + else: + raise ValueError('unexpected type {}'.format(type(t))) + + +def get_has_ddp_rank(): + master_port = os.environ.get('MASTER_PORT', None) + node_rank = os.environ.get('NODE_RANK', None) + local_rank = os.environ.get('LOCAL_RANK', None) + world_size = os.environ.get('WORLD_SIZE', None) + has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None + return has_rank + + +def handle_ddp_subprocess(): + def main_decorator(main_func): + @functools.wraps(main_func) + def new_main(*args, **kwargs): + # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE + parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None) + has_parent = parent_cwd is not None + has_rank = get_has_ddp_rank() + assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}' + + if has_parent: + # we are in the worker + sys.argv.extend([ + f'hydra.run.dir={parent_cwd}', + # 'hydra/hydra_logging=disabled', + # 'hydra/job_logging=disabled' + ]) + # do nothing if this is a top-level process + # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization + + main_func(*args, **kwargs) + return new_main + return main_decorator + + +def handle_ddp_parent_process(): + parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None) + has_parent = parent_cwd is not None + has_rank = get_has_ddp_rank() + assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}' + + if parent_cwd is None: + os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd() + + return has_parent