jayparmr commited on
Commit
19b3da3
·
1 Parent(s): 564eb4b

Upload 118 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +31 -0
  2. config.yml +157 -0
  3. deployment.ipynb +995 -0
  4. handler.py +11 -0
  5. inference.py +341 -0
  6. inference2.py +169 -0
  7. internals/__init__.py +0 -0
  8. internals/data/__init__.py +0 -0
  9. internals/data/dataAccessor.py +104 -0
  10. internals/data/result.py +19 -0
  11. internals/data/task.py +125 -0
  12. internals/pipelines/commons.py +119 -0
  13. internals/pipelines/controlnets.py +221 -0
  14. internals/pipelines/img_classifier.py +24 -0
  15. internals/pipelines/img_to_text.py +31 -0
  16. internals/pipelines/inpainter.py +41 -0
  17. internals/pipelines/object_remove.py +82 -0
  18. internals/pipelines/prompt_modifier.py +54 -0
  19. internals/pipelines/remove_background.py +16 -0
  20. internals/pipelines/safety_checker.py +163 -0
  21. internals/pipelines/twoStepPipeline.py +252 -0
  22. internals/pipelines/upscaler.py +91 -0
  23. internals/util/__init__.py +0 -0
  24. internals/util/args.py +13 -0
  25. internals/util/avatar.py +59 -0
  26. internals/util/cache.py +31 -0
  27. internals/util/commons.py +203 -0
  28. internals/util/config.py +66 -0
  29. internals/util/failure_hander.py +40 -0
  30. internals/util/image.py +18 -0
  31. internals/util/lora_style.py +154 -0
  32. internals/util/slack.py +58 -0
  33. models/ade20k/.DS_Store +0 -0
  34. models/ade20k/__init__.py +1 -0
  35. models/ade20k/base.py +627 -0
  36. models/ade20k/color150.mat +0 -0
  37. models/ade20k/mobilenet.py +154 -0
  38. models/ade20k/object150_info.csv +151 -0
  39. models/ade20k/resnet.py +181 -0
  40. models/ade20k/segm_lib/.DS_Store +0 -0
  41. models/ade20k/segm_lib/nn/.DS_Store +0 -0
  42. models/ade20k/segm_lib/nn/__init__.py +2 -0
  43. models/ade20k/segm_lib/nn/modules/__init__.py +12 -0
  44. models/ade20k/segm_lib/nn/modules/batchnorm.py +329 -0
  45. models/ade20k/segm_lib/nn/modules/comm.py +131 -0
  46. models/ade20k/segm_lib/nn/modules/replicate.py +94 -0
  47. models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py +56 -0
  48. models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py +111 -0
  49. models/ade20k/segm_lib/nn/modules/unittest.py +29 -0
  50. models/ade20k/segm_lib/nn/parallel/__init__.py +1 -0
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # creco-inference
2
+ Unified inference code for SageMaker and Hugging Face endpoints
3
+
4
+ ## Deployment
5
+
6
+ - The inference code (this repo) should be placed in the respective model folder, as shown below.
7
+
8
+ ### SageMaker
9
+
10
+ ```
11
+ model/
12
+ code/
13
+ (repo) <-- The repo inference code as direct child (no sub-folder)
14
+ vae
15
+ unet
16
+ ...
17
+ ```
18
+
19
+ - Refer to `deployment.ipynb` for creating the endpoint.
20
+
21
+ ### Hugging Face
22
+
23
+ ```
24
+ model/
25
+ (repo) <-- The repo inference code as direct child (no sub-folder)
26
+ vae
27
+ unet
28
+ ...
29
+ ```
30
+
31
+ - Refer to the [doc](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint) to create the endpoint.
config.yml ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ run_title: b18_ffc075_batch8x15
2
+ training_model:
3
+ kind: default
4
+ visualize_each_iters: 1000
5
+ concat_mask: true
6
+ store_discr_outputs_for_vis: true
7
+ losses:
8
+ l1:
9
+ weight_missing: 0
10
+ weight_known: 10
11
+ perceptual:
12
+ weight: 0
13
+ adversarial:
14
+ kind: r1
15
+ weight: 10
16
+ gp_coef: 0.001
17
+ mask_as_fake_target: true
18
+ allow_scale_mask: true
19
+ feature_matching:
20
+ weight: 100
21
+ resnet_pl:
22
+ weight: 30
23
+ weights_path: ${env:TORCH_HOME}
24
+
25
+ optimizers:
26
+ generator:
27
+ kind: adam
28
+ lr: 0.001
29
+ discriminator:
30
+ kind: adam
31
+ lr: 0.0001
32
+ visualizer:
33
+ key_order:
34
+ - image
35
+ - predicted_image
36
+ - discr_output_fake
37
+ - discr_output_real
38
+ - inpainted
39
+ rescale_keys:
40
+ - discr_output_fake
41
+ - discr_output_real
42
+ kind: directory
43
+ outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
44
+ location:
45
+ data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
46
+ out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
47
+ tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
48
+ data:
49
+ batch_size: 15
50
+ val_batch_size: 2
51
+ num_workers: 3
52
+ train:
53
+ indir: ${location.data_root_dir}/train
54
+ out_size: 256
55
+ mask_gen_kwargs:
56
+ irregular_proba: 1
57
+ irregular_kwargs:
58
+ max_angle: 4
59
+ max_len: 200
60
+ max_width: 100
61
+ max_times: 5
62
+ min_times: 1
63
+ box_proba: 1
64
+ box_kwargs:
65
+ margin: 10
66
+ bbox_min_size: 30
67
+ bbox_max_size: 150
68
+ max_times: 3
69
+ min_times: 1
70
+ segm_proba: 0
71
+ segm_kwargs:
72
+ confidence_threshold: 0.5
73
+ max_object_area: 0.5
74
+ min_mask_area: 0.07
75
+ downsample_levels: 6
76
+ num_variants_per_mask: 1
77
+ rigidness_mode: 1
78
+ max_foreground_coverage: 0.3
79
+ max_foreground_intersection: 0.7
80
+ max_mask_intersection: 0.1
81
+ max_hidden_area: 0.1
82
+ max_scale_change: 0.25
83
+ horizontal_flip: true
84
+ max_vertical_shift: 0.2
85
+ position_shuffle: true
86
+ transform_variant: distortions
87
+ dataloader_kwargs:
88
+ batch_size: ${data.batch_size}
89
+ shuffle: true
90
+ num_workers: ${data.num_workers}
91
+ val:
92
+ indir: ${location.data_root_dir}/val
93
+ img_suffix: .png
94
+ dataloader_kwargs:
95
+ batch_size: ${data.val_batch_size}
96
+ shuffle: false
97
+ num_workers: ${data.num_workers}
98
+ visual_test:
99
+ indir: ${location.data_root_dir}/korean_test
100
+ img_suffix: _input.png
101
+ pad_out_to_modulo: 32
102
+ dataloader_kwargs:
103
+ batch_size: 1
104
+ shuffle: false
105
+ num_workers: ${data.num_workers}
106
+ generator:
107
+ kind: ffc_resnet
108
+ input_nc: 4
109
+ output_nc: 3
110
+ ngf: 64
111
+ n_downsampling: 3
112
+ n_blocks: 18
113
+ add_out_act: sigmoid
114
+ init_conv_kwargs:
115
+ ratio_gin: 0
116
+ ratio_gout: 0
117
+ enable_lfu: false
118
+ downsample_conv_kwargs:
119
+ ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
120
+ ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
121
+ enable_lfu: false
122
+ resnet_conv_kwargs:
123
+ ratio_gin: 0.75
124
+ ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
125
+ enable_lfu: false
126
+ discriminator:
127
+ kind: pix2pixhd_nlayer
128
+ input_nc: 3
129
+ ndf: 64
130
+ n_layers: 4
131
+ evaluator:
132
+ kind: default
133
+ inpainted_key: inpainted
134
+ integral_kind: ssim_fid100_f1
135
+ trainer:
136
+ kwargs:
137
+ gpus: -1
138
+ accelerator: ddp
139
+ max_epochs: 200
140
+ gradient_clip_val: 1
141
+ log_gpu_memory: None
142
+ limit_train_batches: 25000
143
+ val_check_interval: ${trainer.kwargs.limit_train_batches}
144
+ log_every_n_steps: 1000
145
+ precision: 32
146
+ terminate_on_nan: false
147
+ check_val_every_n_epoch: 1
148
+ num_sanity_val_steps: 8
149
+ limit_val_batches: 1000
150
+ replace_sampler_ddp: false
151
+ checkpoint_kwargs:
152
+ verbose: true
153
+ save_top_k: 5
154
+ save_last: true
155
+ period: 1
156
+ monitor: val_ssim_fid100_f1_total_mean
157
+ mode: max
deployment.ipynb ADDED
@@ -0,0 +1,995 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "5af7e53b-80ff-4058-888d-fe41804f64ba",
7
+ "metadata": {
8
+ "scrolled": true,
9
+ "tags": []
10
+ },
11
+ "outputs": [
12
+ {
13
+ "name": "stdout",
14
+ "output_type": "stream",
15
+ "text": [
16
+ "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n",
17
+ "Requirement already satisfied: pip in /home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages (23.1.2)\n"
18
+ ]
19
+ }
20
+ ],
21
+ "source": [
22
+ "!pip install --upgrade pip\n",
23
+ "!pip install \"sagemaker==2.116.0\" \"huggingface_hub==0.10.1\" --upgrade --quiet"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 5,
29
+ "id": "93ee3d96-400f-46b4-8eb3-0f3f3c853a7e",
30
+ "metadata": {
31
+ "tags": []
32
+ },
33
+ "outputs": [],
34
+ "source": [
35
+ "from distutils.dir_util import copy_tree\n",
36
+ "from pathlib import Path\n",
37
+ "from huggingface_hub import snapshot_download\n",
38
+ "import random\n",
39
+ "import os\n",
40
+ "import tarfile\n",
41
+ "import time\n",
42
+ "import sagemaker\n",
43
+ "from datetime import datetime\n",
44
+ "from sagemaker.s3 import S3Uploader\n",
45
+ "import boto3\n",
46
+ "from sagemaker.huggingface.model import HuggingFaceModel\n",
47
+ "from threading import Thread\n",
48
+ "import subprocess\n",
49
+ "import shutil"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 2,
55
+ "id": "2db37b03-b517-46bc-8602-4999a64399c0",
56
+ "metadata": {
57
+ "tags": []
58
+ },
59
+ "outputs": [],
60
+ "source": [
61
+ "# ------------------------------------------------\n",
62
+ "# Configuration\n",
63
+ "# ------------------------------------------------\n",
64
+ "STAGE = \"prod\"\n",
65
+ "model_configs = [\n",
66
+ " # {\n",
67
+ " # \"inference_2\": False, \n",
68
+ " # \"path\": \"icbinp\",\n",
69
+ " # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n",
70
+ " # #\"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
71
+ " # },\n",
72
+ " # {\n",
73
+ " # \"inference_2\": False, \n",
74
+ " # \"path\": \"icb_with_epi\",\n",
75
+ " # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n",
76
+ " # # \"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
77
+ " # },\n",
78
+ " {\n",
79
+ " \"inference_2\": False, \n",
80
+ " \"path\": \"model_v9\",\n",
81
+ " # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n",
82
+ " \"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
83
+ " },\n",
84
+ " {\n",
85
+ " \"inference_2\": False, \n",
86
+ " \"path\": \"model_v8\",\n",
87
+ " #\"endpoint_name\": \"gamma-10001-2023-05-08-06-14\"\n",
88
+ " \"endpoint_name\": f\"{STAGE}-10001-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
89
+ " },\n",
90
+ " # {\n",
91
+ " # \"inference_2\": False, \n",
92
+ " # \"path\": \"model_v5_anime\",\n",
93
+ " # \"endpoint_name\": \"gamma-10001-2023-05-08-06-14\"\n",
94
+ " # #\"endpoint_name\": f\"{STAGE}-10001-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
95
+ " # },\n",
96
+ " # {\n",
97
+ " # \"inference_2\": False, \n",
98
+ " # \"path\": \"model_v5.3_comic\",\n",
99
+ " # #\"endpoint_name\": \"gamma-10002-2023-05-08-07-22\"\n",
100
+ " # \"endpoint_name\": f\"{STAGE}-10002-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
101
+ " # },\n",
102
+ " {\n",
103
+ " \"inference_2\": False, \n",
104
+ " \"path\": \"model_v10\",\n",
105
+ " # \"endpoint_name\": \"gamma-10002-2023-05-08-07-22\"\n",
106
+ " \"endpoint_name\": f\"{STAGE}-10002-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
107
+ " },\n",
108
+ " {\n",
109
+ " \"inference_2\": True, \n",
110
+ " \"path\": \"model_v5.2_other\",\n",
111
+ " # \"endpoint_name\": \"gamma-other-2023-05-04-09-33\"\n",
112
+ " \"endpoint_name\": f\"{STAGE}-other-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
113
+ " }\n",
114
+ " # {\n",
115
+ " # \"inference_2\": False, \n",
116
+ " # \"path\": \"model_v6_bheem\",\n",
117
+ " # \"endpoint_name\": f\"{STAGE}-10003-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
118
+ " # },\n",
119
+ " # {\n",
120
+ " # \"inference_2\": False, \n",
121
+ " # \"path\": \"model_v12\",\n",
122
+ " # \"endpoint_name\": \"gamma-10003-2023-05-04-05-20\"\n",
123
+ " # # \"endpoint_name\": f\"{STAGE}-10003-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
124
+ " # }\n",
125
+ "]\n",
126
+ "\n",
127
+ "VpcConfig = {\n",
128
+ " \"Subnets\": [\n",
129
+ " \"subnet-0df3f71df4c7b29e5\",\n",
130
+ " \"subnet-0d753b7fc74b5ee68\"\n",
131
+ " ],\n",
132
+ " \"SecurityGroupIds\": [\n",
133
+ " \"sg-033a7948e79a501cd\"\n",
134
+ " ]\n",
135
+ "}"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": 3,
141
+ "id": "d7322ac4-aeeb-4a72-a662-5f3fa74e6454",
142
+ "metadata": {
143
+ "tags": []
144
+ },
145
+ "outputs": [],
146
+ "source": [
147
+ "def compress(tar_dir=None,output_file=\"model.tar.gz\"):\n",
148
+ " parent_dir=os.getcwd()\n",
149
+ " os.chdir(parent_dir + \"/\" + tar_dir)\n",
150
+ " with tarfile.open(os.path.join(parent_dir, output_file), \"w:gz\") as tar:\n",
151
+ " for item in os.listdir('.'):\n",
152
+ " print(\"- \" + item)\n",
153
+ " tar.add(item, arcname=item)\n",
154
+ " os.chdir(parent_dir)\n",
155
+ "\n",
156
+ " \n",
157
+ "def create_model_tar(config):\n",
158
+ " print(\"Copying inference 'code': \" + config.get(\"path\"))\n",
159
+ " \n",
160
+ " model_tar = Path(config.get(\"path\"))\n",
161
+ " if os.path.exists(model_tar.joinpath(\"code\")):\n",
162
+ " shutil.rmtree(model_tar.joinpath(\"code\"))\n",
163
+ " out_tar = config.get(\"path\") + \".tar.gz\"\n",
164
+ " model_tar.mkdir(exist_ok=True)\n",
165
+ " copy_tree(\"code/\", str(model_tar.joinpath(\"code\")))\n",
166
+ " copy_tree(\"laur_style/\", str(model_tar.joinpath(\"laur_style\")))\n",
167
+ " \n",
168
+ " if config.get(\"inference_2\"):\n",
169
+ " os.remove(model_tar.joinpath(\"code\").joinpath(\"inference.py\"))\n",
170
+ " os.rename(model_tar.joinpath(\"code\").joinpath(\"inference2.py\"), model_tar.joinpath(\"code\").joinpath(\"inference.py\"))\n",
171
+ " \n",
172
+ " print(\"Compressing: \" + config.get(\"path\"))\n",
173
+ "\n",
174
+ " if os.path.exists(out_tar):\n",
175
+ " os.remove(out_tar)\n",
176
+ "\n",
177
+ " compress(str(model_tar), out_tar)\n",
178
+ " \n",
179
+ "def upload_to_s3(config):\n",
180
+ " out_tar = config.get(\"path\") + \".tar.gz\"\n",
181
+ " print(\"Uploading model to S3: \" + out_tar)\n",
182
+ " s3_model_uri=S3Uploader.upload(local_path=out_tar, desired_s3_uri=f\"s3://comic-assets/stable-diffusion-v1-4/v2/\")\n",
183
+ " return s3_model_uri\n",
184
+ " \n",
185
+ " \n",
186
+ "def deploy_and_create_endpoint(config, s3_model_uri):\n",
187
+ " sess = sagemaker.Session()\n",
188
+ " # sagemaker session bucket -> used for uploading data, models and logs\n",
189
+ " # sagemaker will automatically create this bucket if it not exists\n",
190
+ " sagemaker_session_bucket=None\n",
191
+ " if sagemaker_session_bucket is None and sess is not None:\n",
192
+ " # set to default bucket if a bucket name is not given\n",
193
+ " sagemaker_session_bucket = sess.default_bucket()\n",
194
+ " try:\n",
195
+ " role = sagemaker.get_execution_role()\n",
196
+ " except ValueError:\n",
197
+ " iam = boto3.client('iam')\n",
198
+ " role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
199
+ "\n",
200
+ " sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
201
+ " \n",
202
+ " huggingface_model = HuggingFaceModel(\n",
203
+ " model_data=s3_model_uri, # path to your model and script\n",
204
+ " role=role, # iam role with permissions to create an Endpoint\n",
205
+ " transformers_version=\"4.17\", # transformers version used\n",
206
+ " pytorch_version=\"1.10\", # pytorch version used\n",
207
+ " py_version='py38',# python version used\n",
208
+ " vpc_config=VpcConfig,\n",
209
+ " )\n",
210
+ "\n",
211
+ " print(\"Creating endpoint: \" + config.get(\"endpoint_name\"))\n",
212
+ "\n",
213
+ " predictor = huggingface_model.deploy(\n",
214
+ " initial_instance_count=1,\n",
215
+ " instance_type=\"ml.g4dn.xlarge\",\n",
216
+ " endpoint_name=config.get(\"endpoint_name\")\n",
217
+ " )\n",
218
+ "\n",
219
+ " \n",
220
+ "def start_process(config):\n",
221
+ " try:\n",
222
+ " create_model_tar(config)\n",
223
+ " s3_model_uri = upload_to_s3(config)\n",
224
+ " #s3_model_uri = \"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.2_other.tar.gz\"\n",
225
+ " deploy_and_create_endpoint(config, s3_model_uri)\n",
226
+ " except Exception as e:\n",
227
+ " print(\"Failed to deploy: \" + config.get(\"path\") + \"\\n\" + str(e))"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 4,
233
+ "id": "cdc04669-90a5-4b43-8499-ad1d2dd63a4c",
234
+ "metadata": {
235
+ "tags": []
236
+ },
237
+ "outputs": [
238
+ {
239
+ "name": "stdout",
240
+ "output_type": "stream",
241
+ "text": [
242
+ "Copying inference 'code': model_v9\n",
243
+ "Compressing: model_v9\n",
244
+ "- scheduler\n",
245
+ "- vae\n",
246
+ "- .ipynb_checkpoints\n",
247
+ "- feature_extractor\n",
248
+ "- tokenizer\n",
249
+ "- text_encoder\n",
250
+ "- model_index.json\n",
251
+ "- laur_style\n",
252
+ "- code\n",
253
+ "- unet\n",
254
+ "- args.json\n",
255
+ "Uploading model to S3: model_v9.tar.gz\n",
256
+ "Creating endpoint: gamma-10000-2023-05-16-14-55\n",
257
+ "-----------------!\n",
258
+ "\n",
259
+ "Completed in : 992.3517553806305s\n"
260
+ ]
261
+ }
262
+ ],
263
+ "source": [
264
+ "threads = []\n",
265
+ "\n",
266
+ "os.chdir(\"/home/ec2-user/SageMaker\")\n",
267
+ "\n",
268
+ "start_time = time.time()\n",
269
+ "\n",
270
+ "for config in model_configs:\n",
271
+ " thread = Thread(target=start_process, args=(config,))\n",
272
+ " thread.start()\n",
273
+ " thread.join()\n",
274
+ " threads.append(thread)\n",
275
+ "\n",
276
+ "for thread in threads:\n",
277
+ " thread.join()\n",
278
+ " \n",
279
+ "print(\"\\n\\nCompleted in : \" + str(time.time() - start_time) + \"s\")\n",
280
+ "\n",
281
+ "# For redeploying gamma endpoints or promoting gamma endpoints to prod\n",
282
+ "\n",
283
+ "# thread1 = Thread(target=deploy_and_create_endpoint, args=(model_configs[0],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v9.tar.gz\",))\n",
284
+ "# thread2 = Thread(target=deploy_and_create_endpoint, args=(model_configs[1],\"s3://comic-assets/stable-diffusion-v1-4/v2//anime_mode_with_lora.tar.gz\",))\n",
285
+ "# thread3 = Thread(target=deploy_and_create_endpoint, args=(model_configs[0],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.3_comic.tar.gz\",))\n",
286
+ "# thread4 = Thread(target=deploy_and_create_endpoint, args=(model_configs[3],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.2_other.tar.gz\",))\n",
287
+ "\n",
288
+ "# thread1.start()\n",
289
+ "# thread2.start()\n",
290
+ "# thread3.start()\n",
291
+ "# thread4.start()\n",
292
+ "\n",
293
+ "# thread1.join()\n",
294
+ "# thread2.join()\n",
295
+ "# thread3.join()\n",
296
+ "# thread4.join()\n",
297
+ "\n",
298
+ "# print(\"Done\")\n"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "id": "39f007f2-0ff8-487c-b5d7-158f0947b7fd",
305
+ "metadata": {
306
+ "collapsed": true,
307
+ "jupyter": {
308
+ "outputs_hidden": true
309
+ },
310
+ "tags": []
311
+ },
312
+ "outputs": [],
313
+ "source": [
314
+ "\n",
315
+ "# import sagemaker\n",
316
+ "# import boto3\n",
317
+ "# import time \n",
318
+ "\n",
319
+ "# start = time.time()\n",
320
+ "\n",
321
+ "# sess = sagemaker.Session()\n",
322
+ "# # sagemaker session bucket -> used for uploading data, models and logs\n",
323
+ "# # sagemaker will automatically create this bucket if it not exists\n",
324
+ "# sagemaker_session_bucket=None\n",
325
+ "# if sagemaker_session_bucket is None and sess is not None:\n",
326
+ "# # set to default bucket if a bucket name is not given\n",
327
+ "# sagemaker_session_bucket = sess.default_bucket()\n",
328
+ "\n",
329
+ "# try:\n",
330
+ "# role = sagemaker.get_execution_role()\n",
331
+ "# except ValueError:\n",
332
+ "# iam = boto3.client('iam')\n",
333
+ "# role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
334
+ "\n",
335
+ "# sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
336
+ "\n",
337
+ "# print(f\"sagemaker role arn: {role}\")\n",
338
+ "# print(f\"sagemaker bucket: {sess.default_bucket()}\")\n",
339
+ "# print(f\"sagemaker session region: {sess.boto_region_name}\")\n",
340
+ "# print(sagemaker.get_execution_role())\n",
341
+ "\n",
342
+ "# from sagemaker.s3 import S3Uploader\n",
343
+ "\n",
344
+ "# print(\"Uploading model to S3\")\n",
345
+ "\n",
346
+ "# # upload model.tar.gz to s3\n",
347
+ "# s3_model_uri=S3Uploader.upload(local_path=\"model.tar.gz\", desired_s3_uri=f\"s3://comic-assets/stable-diffusion-v1-4/v2/\")\n",
348
+ "\n",
349
+ "# print(f\"model uploaded to: {s3_model_uri}\")\n",
350
+ "\n",
351
+ "\n",
352
+ "# from sagemaker.huggingface.model import HuggingFaceModel\n",
353
+ "\n",
354
+ "# VpcConfig = {\n",
355
+ "# \"Subnets\": [\n",
356
+ "# \"subnet-0df3f71df4c7b29e5\",\n",
357
+ "# \"subnet-0d753b7fc74b5ee68\"\n",
358
+ "# ],\n",
359
+ "# \"SecurityGroupIds\": [\n",
360
+ "# \"sg-033a7948e79a501cd\"\n",
361
+ "# ]\n",
362
+ "# }\n",
363
+ "\n",
364
+ "# # create Hugging Face Model Class\n",
365
+ "# huggingface_model = HuggingFaceModel(\n",
366
+ "# model_data=s3_model_uri, # path to your model and script\n",
367
+ "# role=role, # iam role with permissions to create an Endpoint\n",
368
+ "# transformers_version=\"4.17\", # transformers version used\n",
369
+ "# pytorch_version=\"1.10\", # pytorch version used\n",
370
+ "# py_version='py38',# python version used\n",
371
+ "# vpc_config=VpcConfig,\n",
372
+ "# )\n",
373
+ "\n",
374
+ "# print(\"Deploying model\")\n",
375
+ "\n",
376
+ "# predictor = huggingface_model.deploy(\n",
377
+ "# initial_instance_count=1,\n",
378
+ "# instance_type=\"ml.g4dn.xlarge\",\n",
379
+ "# # endpoint_name=endpoint_name\n",
380
+ "# )\n",
381
+ "\n",
382
+ "# print(f\"Done {time.time() - start}\")"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "id": "aa95a262-d6ba-4e61-8657-6f8e5bab74a1",
389
+ "metadata": {
390
+ "tags": []
391
+ },
392
+ "outputs": [],
393
+ "source": [
394
+ "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "524ca546-2a67-4b51-9cda-a1b51a49c339",
401
+ "metadata": {
402
+ "tags": []
403
+ },
404
+ "outputs": [],
405
+ "source": [
406
+ "!sudo yum install git-lfs"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "id": "3c7e661f-5eee-4357-80f6-e7563941a812",
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": []
416
+ }
417
+ ],
418
+ "metadata": {
419
+ "availableInstances": [
420
+ {
421
+ "_defaultOrder": 0,
422
+ "_isFastLaunch": true,
423
+ "category": "General purpose",
424
+ "gpuNum": 0,
425
+ "hideHardwareSpecs": false,
426
+ "memoryGiB": 4,
427
+ "name": "ml.t3.medium",
428
+ "vcpuNum": 2
429
+ },
430
+ {
431
+ "_defaultOrder": 1,
432
+ "_isFastLaunch": false,
433
+ "category": "General purpose",
434
+ "gpuNum": 0,
435
+ "hideHardwareSpecs": false,
436
+ "memoryGiB": 8,
437
+ "name": "ml.t3.large",
438
+ "vcpuNum": 2
439
+ },
440
+ {
441
+ "_defaultOrder": 2,
442
+ "_isFastLaunch": false,
443
+ "category": "General purpose",
444
+ "gpuNum": 0,
445
+ "hideHardwareSpecs": false,
446
+ "memoryGiB": 16,
447
+ "name": "ml.t3.xlarge",
448
+ "vcpuNum": 4
449
+ },
450
+ {
451
+ "_defaultOrder": 3,
452
+ "_isFastLaunch": false,
453
+ "category": "General purpose",
454
+ "gpuNum": 0,
455
+ "hideHardwareSpecs": false,
456
+ "memoryGiB": 32,
457
+ "name": "ml.t3.2xlarge",
458
+ "vcpuNum": 8
459
+ },
460
+ {
461
+ "_defaultOrder": 4,
462
+ "_isFastLaunch": true,
463
+ "category": "General purpose",
464
+ "gpuNum": 0,
465
+ "hideHardwareSpecs": false,
466
+ "memoryGiB": 8,
467
+ "name": "ml.m5.large",
468
+ "vcpuNum": 2
469
+ },
470
+ {
471
+ "_defaultOrder": 5,
472
+ "_isFastLaunch": false,
473
+ "category": "General purpose",
474
+ "gpuNum": 0,
475
+ "hideHardwareSpecs": false,
476
+ "memoryGiB": 16,
477
+ "name": "ml.m5.xlarge",
478
+ "vcpuNum": 4
479
+ },
480
+ {
481
+ "_defaultOrder": 6,
482
+ "_isFastLaunch": false,
483
+ "category": "General purpose",
484
+ "gpuNum": 0,
485
+ "hideHardwareSpecs": false,
486
+ "memoryGiB": 32,
487
+ "name": "ml.m5.2xlarge",
488
+ "vcpuNum": 8
489
+ },
490
+ {
491
+ "_defaultOrder": 7,
492
+ "_isFastLaunch": false,
493
+ "category": "General purpose",
494
+ "gpuNum": 0,
495
+ "hideHardwareSpecs": false,
496
+ "memoryGiB": 64,
497
+ "name": "ml.m5.4xlarge",
498
+ "vcpuNum": 16
499
+ },
500
+ {
501
+ "_defaultOrder": 8,
502
+ "_isFastLaunch": false,
503
+ "category": "General purpose",
504
+ "gpuNum": 0,
505
+ "hideHardwareSpecs": false,
506
+ "memoryGiB": 128,
507
+ "name": "ml.m5.8xlarge",
508
+ "vcpuNum": 32
509
+ },
510
+ {
511
+ "_defaultOrder": 9,
512
+ "_isFastLaunch": false,
513
+ "category": "General purpose",
514
+ "gpuNum": 0,
515
+ "hideHardwareSpecs": false,
516
+ "memoryGiB": 192,
517
+ "name": "ml.m5.12xlarge",
518
+ "vcpuNum": 48
519
+ },
520
+ {
521
+ "_defaultOrder": 10,
522
+ "_isFastLaunch": false,
523
+ "category": "General purpose",
524
+ "gpuNum": 0,
525
+ "hideHardwareSpecs": false,
526
+ "memoryGiB": 256,
527
+ "name": "ml.m5.16xlarge",
528
+ "vcpuNum": 64
529
+ },
530
+ {
531
+ "_defaultOrder": 11,
532
+ "_isFastLaunch": false,
533
+ "category": "General purpose",
534
+ "gpuNum": 0,
535
+ "hideHardwareSpecs": false,
536
+ "memoryGiB": 384,
537
+ "name": "ml.m5.24xlarge",
538
+ "vcpuNum": 96
539
+ },
540
+ {
541
+ "_defaultOrder": 12,
542
+ "_isFastLaunch": false,
543
+ "category": "General purpose",
544
+ "gpuNum": 0,
545
+ "hideHardwareSpecs": false,
546
+ "memoryGiB": 8,
547
+ "name": "ml.m5d.large",
548
+ "vcpuNum": 2
549
+ },
550
+ {
551
+ "_defaultOrder": 13,
552
+ "_isFastLaunch": false,
553
+ "category": "General purpose",
554
+ "gpuNum": 0,
555
+ "hideHardwareSpecs": false,
556
+ "memoryGiB": 16,
557
+ "name": "ml.m5d.xlarge",
558
+ "vcpuNum": 4
559
+ },
560
+ {
561
+ "_defaultOrder": 14,
562
+ "_isFastLaunch": false,
563
+ "category": "General purpose",
564
+ "gpuNum": 0,
565
+ "hideHardwareSpecs": false,
566
+ "memoryGiB": 32,
567
+ "name": "ml.m5d.2xlarge",
568
+ "vcpuNum": 8
569
+ },
570
+ {
571
+ "_defaultOrder": 15,
572
+ "_isFastLaunch": false,
573
+ "category": "General purpose",
574
+ "gpuNum": 0,
575
+ "hideHardwareSpecs": false,
576
+ "memoryGiB": 64,
577
+ "name": "ml.m5d.4xlarge",
578
+ "vcpuNum": 16
579
+ },
580
+ {
581
+ "_defaultOrder": 16,
582
+ "_isFastLaunch": false,
583
+ "category": "General purpose",
584
+ "gpuNum": 0,
585
+ "hideHardwareSpecs": false,
586
+ "memoryGiB": 128,
587
+ "name": "ml.m5d.8xlarge",
588
+ "vcpuNum": 32
589
+ },
590
+ {
591
+ "_defaultOrder": 17,
592
+ "_isFastLaunch": false,
593
+ "category": "General purpose",
594
+ "gpuNum": 0,
595
+ "hideHardwareSpecs": false,
596
+ "memoryGiB": 192,
597
+ "name": "ml.m5d.12xlarge",
598
+ "vcpuNum": 48
599
+ },
600
+ {
601
+ "_defaultOrder": 18,
602
+ "_isFastLaunch": false,
603
+ "category": "General purpose",
604
+ "gpuNum": 0,
605
+ "hideHardwareSpecs": false,
606
+ "memoryGiB": 256,
607
+ "name": "ml.m5d.16xlarge",
608
+ "vcpuNum": 64
609
+ },
610
+ {
611
+ "_defaultOrder": 19,
612
+ "_isFastLaunch": false,
613
+ "category": "General purpose",
614
+ "gpuNum": 0,
615
+ "hideHardwareSpecs": false,
616
+ "memoryGiB": 384,
617
+ "name": "ml.m5d.24xlarge",
618
+ "vcpuNum": 96
619
+ },
620
+ {
621
+ "_defaultOrder": 20,
622
+ "_isFastLaunch": false,
623
+ "category": "General purpose",
624
+ "gpuNum": 0,
625
+ "hideHardwareSpecs": true,
626
+ "memoryGiB": 0,
627
+ "name": "ml.geospatial.interactive",
628
+ "supportedImageNames": [
629
+ "sagemaker-geospatial-v1-0"
630
+ ],
631
+ "vcpuNum": 0
632
+ },
633
+ {
634
+ "_defaultOrder": 21,
635
+ "_isFastLaunch": true,
636
+ "category": "Compute optimized",
637
+ "gpuNum": 0,
638
+ "hideHardwareSpecs": false,
639
+ "memoryGiB": 4,
640
+ "name": "ml.c5.large",
641
+ "vcpuNum": 2
642
+ },
643
+ {
644
+ "_defaultOrder": 22,
645
+ "_isFastLaunch": false,
646
+ "category": "Compute optimized",
647
+ "gpuNum": 0,
648
+ "hideHardwareSpecs": false,
649
+ "memoryGiB": 8,
650
+ "name": "ml.c5.xlarge",
651
+ "vcpuNum": 4
652
+ },
653
+ {
654
+ "_defaultOrder": 23,
655
+ "_isFastLaunch": false,
656
+ "category": "Compute optimized",
657
+ "gpuNum": 0,
658
+ "hideHardwareSpecs": false,
659
+ "memoryGiB": 16,
660
+ "name": "ml.c5.2xlarge",
661
+ "vcpuNum": 8
662
+ },
663
+ {
664
+ "_defaultOrder": 24,
665
+ "_isFastLaunch": false,
666
+ "category": "Compute optimized",
667
+ "gpuNum": 0,
668
+ "hideHardwareSpecs": false,
669
+ "memoryGiB": 32,
670
+ "name": "ml.c5.4xlarge",
671
+ "vcpuNum": 16
672
+ },
673
+ {
674
+ "_defaultOrder": 25,
675
+ "_isFastLaunch": false,
676
+ "category": "Compute optimized",
677
+ "gpuNum": 0,
678
+ "hideHardwareSpecs": false,
679
+ "memoryGiB": 72,
680
+ "name": "ml.c5.9xlarge",
681
+ "vcpuNum": 36
682
+ },
683
+ {
684
+ "_defaultOrder": 26,
685
+ "_isFastLaunch": false,
686
+ "category": "Compute optimized",
687
+ "gpuNum": 0,
688
+ "hideHardwareSpecs": false,
689
+ "memoryGiB": 96,
690
+ "name": "ml.c5.12xlarge",
691
+ "vcpuNum": 48
692
+ },
693
+ {
694
+ "_defaultOrder": 27,
695
+ "_isFastLaunch": false,
696
+ "category": "Compute optimized",
697
+ "gpuNum": 0,
698
+ "hideHardwareSpecs": false,
699
+ "memoryGiB": 144,
700
+ "name": "ml.c5.18xlarge",
701
+ "vcpuNum": 72
702
+ },
703
+ {
704
+ "_defaultOrder": 28,
705
+ "_isFastLaunch": false,
706
+ "category": "Compute optimized",
707
+ "gpuNum": 0,
708
+ "hideHardwareSpecs": false,
709
+ "memoryGiB": 192,
710
+ "name": "ml.c5.24xlarge",
711
+ "vcpuNum": 96
712
+ },
713
+ {
714
+ "_defaultOrder": 29,
715
+ "_isFastLaunch": true,
716
+ "category": "Accelerated computing",
717
+ "gpuNum": 1,
718
+ "hideHardwareSpecs": false,
719
+ "memoryGiB": 16,
720
+ "name": "ml.g4dn.xlarge",
721
+ "vcpuNum": 4
722
+ },
723
+ {
724
+ "_defaultOrder": 30,
725
+ "_isFastLaunch": false,
726
+ "category": "Accelerated computing",
727
+ "gpuNum": 1,
728
+ "hideHardwareSpecs": false,
729
+ "memoryGiB": 32,
730
+ "name": "ml.g4dn.2xlarge",
731
+ "vcpuNum": 8
732
+ },
733
+ {
734
+ "_defaultOrder": 31,
735
+ "_isFastLaunch": false,
736
+ "category": "Accelerated computing",
737
+ "gpuNum": 1,
738
+ "hideHardwareSpecs": false,
739
+ "memoryGiB": 64,
740
+ "name": "ml.g4dn.4xlarge",
741
+ "vcpuNum": 16
742
+ },
743
+ {
744
+ "_defaultOrder": 32,
745
+ "_isFastLaunch": false,
746
+ "category": "Accelerated computing",
747
+ "gpuNum": 1,
748
+ "hideHardwareSpecs": false,
749
+ "memoryGiB": 128,
750
+ "name": "ml.g4dn.8xlarge",
751
+ "vcpuNum": 32
752
+ },
753
+ {
754
+ "_defaultOrder": 33,
755
+ "_isFastLaunch": false,
756
+ "category": "Accelerated computing",
757
+ "gpuNum": 4,
758
+ "hideHardwareSpecs": false,
759
+ "memoryGiB": 192,
760
+ "name": "ml.g4dn.12xlarge",
761
+ "vcpuNum": 48
762
+ },
763
+ {
764
+ "_defaultOrder": 34,
765
+ "_isFastLaunch": false,
766
+ "category": "Accelerated computing",
767
+ "gpuNum": 1,
768
+ "hideHardwareSpecs": false,
769
+ "memoryGiB": 256,
770
+ "name": "ml.g4dn.16xlarge",
771
+ "vcpuNum": 64
772
+ },
773
+ {
774
+ "_defaultOrder": 35,
775
+ "_isFastLaunch": false,
776
+ "category": "Accelerated computing",
777
+ "gpuNum": 1,
778
+ "hideHardwareSpecs": false,
779
+ "memoryGiB": 61,
780
+ "name": "ml.p3.2xlarge",
781
+ "vcpuNum": 8
782
+ },
783
+ {
784
+ "_defaultOrder": 36,
785
+ "_isFastLaunch": false,
786
+ "category": "Accelerated computing",
787
+ "gpuNum": 4,
788
+ "hideHardwareSpecs": false,
789
+ "memoryGiB": 244,
790
+ "name": "ml.p3.8xlarge",
791
+ "vcpuNum": 32
792
+ },
793
+ {
794
+ "_defaultOrder": 37,
795
+ "_isFastLaunch": false,
796
+ "category": "Accelerated computing",
797
+ "gpuNum": 8,
798
+ "hideHardwareSpecs": false,
799
+ "memoryGiB": 488,
800
+ "name": "ml.p3.16xlarge",
801
+ "vcpuNum": 64
802
+ },
803
+ {
804
+ "_defaultOrder": 38,
805
+ "_isFastLaunch": false,
806
+ "category": "Accelerated computing",
807
+ "gpuNum": 8,
808
+ "hideHardwareSpecs": false,
809
+ "memoryGiB": 768,
810
+ "name": "ml.p3dn.24xlarge",
811
+ "vcpuNum": 96
812
+ },
813
+ {
814
+ "_defaultOrder": 39,
815
+ "_isFastLaunch": false,
816
+ "category": "Memory Optimized",
817
+ "gpuNum": 0,
818
+ "hideHardwareSpecs": false,
819
+ "memoryGiB": 16,
820
+ "name": "ml.r5.large",
821
+ "vcpuNum": 2
822
+ },
823
+ {
824
+ "_defaultOrder": 40,
825
+ "_isFastLaunch": false,
826
+ "category": "Memory Optimized",
827
+ "gpuNum": 0,
828
+ "hideHardwareSpecs": false,
829
+ "memoryGiB": 32,
830
+ "name": "ml.r5.xlarge",
831
+ "vcpuNum": 4
832
+ },
833
+ {
834
+ "_defaultOrder": 41,
835
+ "_isFastLaunch": false,
836
+ "category": "Memory Optimized",
837
+ "gpuNum": 0,
838
+ "hideHardwareSpecs": false,
839
+ "memoryGiB": 64,
840
+ "name": "ml.r5.2xlarge",
841
+ "vcpuNum": 8
842
+ },
843
+ {
844
+ "_defaultOrder": 42,
845
+ "_isFastLaunch": false,
846
+ "category": "Memory Optimized",
847
+ "gpuNum": 0,
848
+ "hideHardwareSpecs": false,
849
+ "memoryGiB": 128,
850
+ "name": "ml.r5.4xlarge",
851
+ "vcpuNum": 16
852
+ },
853
+ {
854
+ "_defaultOrder": 43,
855
+ "_isFastLaunch": false,
856
+ "category": "Memory Optimized",
857
+ "gpuNum": 0,
858
+ "hideHardwareSpecs": false,
859
+ "memoryGiB": 256,
860
+ "name": "ml.r5.8xlarge",
861
+ "vcpuNum": 32
862
+ },
863
+ {
864
+ "_defaultOrder": 44,
865
+ "_isFastLaunch": false,
866
+ "category": "Memory Optimized",
867
+ "gpuNum": 0,
868
+ "hideHardwareSpecs": false,
869
+ "memoryGiB": 384,
870
+ "name": "ml.r5.12xlarge",
871
+ "vcpuNum": 48
872
+ },
873
+ {
874
+ "_defaultOrder": 45,
875
+ "_isFastLaunch": false,
876
+ "category": "Memory Optimized",
877
+ "gpuNum": 0,
878
+ "hideHardwareSpecs": false,
879
+ "memoryGiB": 512,
880
+ "name": "ml.r5.16xlarge",
881
+ "vcpuNum": 64
882
+ },
883
+ {
884
+ "_defaultOrder": 46,
885
+ "_isFastLaunch": false,
886
+ "category": "Memory Optimized",
887
+ "gpuNum": 0,
888
+ "hideHardwareSpecs": false,
889
+ "memoryGiB": 768,
890
+ "name": "ml.r5.24xlarge",
891
+ "vcpuNum": 96
892
+ },
893
+ {
894
+ "_defaultOrder": 47,
895
+ "_isFastLaunch": false,
896
+ "category": "Accelerated computing",
897
+ "gpuNum": 1,
898
+ "hideHardwareSpecs": false,
899
+ "memoryGiB": 16,
900
+ "name": "ml.g5.xlarge",
901
+ "vcpuNum": 4
902
+ },
903
+ {
904
+ "_defaultOrder": 48,
905
+ "_isFastLaunch": false,
906
+ "category": "Accelerated computing",
907
+ "gpuNum": 1,
908
+ "hideHardwareSpecs": false,
909
+ "memoryGiB": 32,
910
+ "name": "ml.g5.2xlarge",
911
+ "vcpuNum": 8
912
+ },
913
+ {
914
+ "_defaultOrder": 49,
915
+ "_isFastLaunch": false,
916
+ "category": "Accelerated computing",
917
+ "gpuNum": 1,
918
+ "hideHardwareSpecs": false,
919
+ "memoryGiB": 64,
920
+ "name": "ml.g5.4xlarge",
921
+ "vcpuNum": 16
922
+ },
923
+ {
924
+ "_defaultOrder": 50,
925
+ "_isFastLaunch": false,
926
+ "category": "Accelerated computing",
927
+ "gpuNum": 1,
928
+ "hideHardwareSpecs": false,
929
+ "memoryGiB": 128,
930
+ "name": "ml.g5.8xlarge",
931
+ "vcpuNum": 32
932
+ },
933
+ {
934
+ "_defaultOrder": 51,
935
+ "_isFastLaunch": false,
936
+ "category": "Accelerated computing",
937
+ "gpuNum": 1,
938
+ "hideHardwareSpecs": false,
939
+ "memoryGiB": 256,
940
+ "name": "ml.g5.16xlarge",
941
+ "vcpuNum": 64
942
+ },
943
+ {
944
+ "_defaultOrder": 52,
945
+ "_isFastLaunch": false,
946
+ "category": "Accelerated computing",
947
+ "gpuNum": 4,
948
+ "hideHardwareSpecs": false,
949
+ "memoryGiB": 192,
950
+ "name": "ml.g5.12xlarge",
951
+ "vcpuNum": 48
952
+ },
953
+ {
954
+ "_defaultOrder": 53,
955
+ "_isFastLaunch": false,
956
+ "category": "Accelerated computing",
957
+ "gpuNum": 4,
958
+ "hideHardwareSpecs": false,
959
+ "memoryGiB": 384,
960
+ "name": "ml.g5.24xlarge",
961
+ "vcpuNum": 96
962
+ },
963
+ {
964
+ "_defaultOrder": 54,
965
+ "_isFastLaunch": false,
966
+ "category": "Accelerated computing",
967
+ "gpuNum": 8,
968
+ "hideHardwareSpecs": false,
969
+ "memoryGiB": 768,
970
+ "name": "ml.g5.48xlarge",
971
+ "vcpuNum": 192
972
+ }
973
+ ],
974
+ "instance_type": "ml.t3.medium",
975
+ "kernelspec": {
976
+ "display_name": "conda_pytorch_p39",
977
+ "language": "python",
978
+ "name": "conda_pytorch_p39"
979
+ },
980
+ "language_info": {
981
+ "codemirror_mode": {
982
+ "name": "ipython",
983
+ "version": 3
984
+ },
985
+ "file_extension": ".py",
986
+ "mimetype": "text/x-python",
987
+ "name": "python",
988
+ "nbconvert_exporter": "python",
989
+ "pygments_lexer": "ipython3",
990
+ "version": "3.9.15"
991
+ }
992
+ },
993
+ "nbformat": 4,
994
+ "nbformat_minor": 5
995
+ }
handler.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+
3
+ from inference import model_fn, predict_fn
4
+
5
+
6
class EndpointHandler:
    """Endpoint entry point that delegates to the ``inference`` module.

    Constructing the handler loads the models through ``model_fn``; calling
    the instance forwards the request payload to ``predict_fn``.
    """

    def __init__(self, path=""):
        # ``__init__`` must return None. The original returned the result of
        # ``model_fn(path)``, which only worked because that result is None.
        model_fn(path)

    def __call__(self, data: Any) -> Any:
        """Run the task described by ``data`` and return ``predict_fn``'s result."""
        # The second positional argument of ``predict_fn`` is not supplied by
        # this handler, so None is passed.
        return predict_fn(data, None)
inference.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import torch
4
+
5
+ from internals.data.dataAccessor import update_db
6
+ from internals.data.task import Task, TaskType
7
+ from internals.pipelines.commons import Img2Img, Text2Img
8
+ from internals.pipelines.controlnets import ControlNet
9
+ from internals.pipelines.img_classifier import ImageClassifier
10
+ from internals.pipelines.img_to_text import Image2Text
11
+ from internals.pipelines.prompt_modifier import PromptModifier
12
+ from internals.pipelines.safety_checker import SafetyChecker
13
+ from internals.util.args import apply_style_args
14
+ from internals.util.avatar import Avatar
15
+ from internals.util.cache import auto_clear_cuda_and_gc
16
+ from internals.util.commons import pickPoses, upload_image, upload_images
17
+ from internals.util.config import set_configs_from_task, set_root_dir
18
+ from internals.util.failure_hander import FailureHandler
19
+ from internals.util.lora_style import LoraStyle
20
+ from internals.util.slack import Slack
21
+
22
+ torch.backends.cudnn.benchmark = True
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+
25
+ num_return_sequences = 4 # the number of results to generate
26
+ auto_mode = False
27
+
28
+ prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
29
+ img2text = Image2Text()
30
+ img_classifier = ImageClassifier()
31
+ controlnet = ControlNet()
32
+ lora_style = LoraStyle()
33
+ text2img_pipe = Text2Img()
34
+ img2img_pipe = Img2Img()
35
+ safety_checker = SafetyChecker()
36
+ slack = Slack()
37
+ avatar = Avatar()
38
+
39
+
40
def get_patched_prompt(task: Task):
    """Build the prompt lists used for one generation request.

    Returns a ``(prompt, ori_prompt)`` tuple. ``prompt`` is either the output
    of the prompt modifier (when prompt engineering is enabled for the task)
    or the task prompt repeated ``num_return_sequences`` times. ``ori_prompt``
    is always the task prompt repeated ``num_return_sequences`` times. Every
    entry of both lists gets avatar code names and the style prefix applied.
    """

    def add_style_and_character(prompt: List[str], additional: Optional[str] = None):
        # Rewrites the entries of ``prompt`` in place.
        for idx, text in enumerate(prompt):
            text = avatar.add_code_names(text)
            text = lora_style.prepend_style_to_prompt(text, task.get_style())
            if additional:
                text = additional + " " + text
            prompt[idx] = text

    base_prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(base_prompt)
    else:
        prompt = [base_prompt] * num_return_sequences

    ori_prompt = [task.get_prompt()] * num_return_sequences

    # Image classification is not run here, so no class name is prepended.
    class_name = None
    add_style_and_character(ori_prompt, class_name)
    add_style_and_character(prompt, class_name)

    print({"prompts": prompt})

    return (prompt, ori_prompt)
68
+
69
+
70
def get_patched_prompt_tile_upscale(task: Task):
    """Build the single prompt string used by the tile upscaler.

    Uses the task prompt when it is non-empty, otherwise a caption produced
    by ``img2text`` for the task image. Avatar code names and the style
    prefix are applied, and the image class name is prepended when the
    classifier returns one.
    """
    prompt = task.get_prompt() or img2text.process(task.get_imageUrl())

    prompt = avatar.add_code_names(prompt)
    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())

    class_name = img_classifier.classify(
        task.get_imageUrl(), task.get_width(), task.get_height()
    )
    # Guard against an empty/None class name (string concatenation with None
    # raises TypeError), matching the ``if additional:`` check used by
    # ``get_patched_prompt``.
    if class_name:
        prompt = class_name + " " + prompt

    print({"prompt": prompt})

    return prompt
87
+
88
+
89
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def canny(task: Task):
    """Generate images with the canny ControlNet and upload them.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_canny()

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()

    try:
        images, has_nsfw = controlnet.process_canny(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_cy_guidance_scale(),
            negative_prompt=[
                f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
            ]
            * num_return_sequences,
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_canny", task.get_taskId())
    finally:
        # Undo the style patch and release the ControlNet even when
        # generation or upload fails, so the shared pipeline is not left
        # patched for the next request.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
126
+
127
+
128
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def tile_upscale(task: Task):
    """Upscale the task image with the tile ControlNet and upload the result.

    Returns a dict with ``modified_prompts``, ``generated_image_url`` and
    ``has_nsfw``.
    """
    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())

    prompt = get_patched_prompt_tile_upscale(task)

    controlnet.load_tile_upscaler()

    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
    lora_patcher.patch()

    try:
        # NOTE(review): unlike the other tasks, ``lora_patcher.kwargs()`` is
        # not forwarded here; left unchanged - confirm this is intentional.
        images, has_nsfw = controlnet.process_tile_upscaler(
            imageUrl=task.get_imageUrl(),
            seed=task.get_seed(),
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            prompt=prompt,
            resize_dimension=task.get_resize_dimension(),
            negative_prompt=task.get_negative_prompt(),
            guidance_scale=task.get_ti_guidance_scale(),
        )

        generated_image_url = upload_image(images[0], output_key)
    finally:
        # Undo the style patch and release the ControlNet even when
        # generation or upload fails.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_url": generated_image_url,
        "has_nsfw": has_nsfw,
    }
163
+
164
+
165
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
    """Generate images with the pose ControlNet and upload them.

    ``s3_outkey`` is passed to ``upload_images`` together with the task id.
    When ``poses`` is None, the pose detected in the task image is repeated
    ``num_return_sequences`` times and used instead.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    controlnet.load_pose()

    # pipe2 is used for canny and pose
    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
    lora_patcher.patch()

    try:
        if poses is None:
            poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences

        images, has_nsfw = controlnet.process_pose(
            prompt=prompt,
            image=poses,
            seed=task.get_seed(),
            steps=task.get_steps(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            width=task.get_width(),
            height=task.get_height(),
            guidance_scale=task.get_po_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
    finally:
        # Undo the style patch and release the ControlNet even when pose
        # detection, generation or upload fails.
        lora_patcher.cleanup()
        controlnet.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
202
+
203
+
204
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def text2img(task: Task):
    """Generate images from the task prompt and upload them.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, ori_prompt = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()

    try:
        # Seed the global torch RNG with the task seed before generating.
        torch.manual_seed(task.get_seed())

        images, has_nsfw = text2img_pipe.process(
            prompt=ori_prompt,
            modified_prompts=prompt,
            num_inference_steps=task.get_steps(),
            guidance_scale=7.5,
            height=task.get_height(),
            width=task.get_width(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            iteration=task.get_iteration(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "", task.get_taskId())
    finally:
        # Undo the style patch even when generation or upload fails, so the
        # shared pipeline is not left patched for the next request.
        lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
236
+
237
+
238
@update_db
@auto_clear_cuda_and_gc(controlnet)
@slack.auto_send_alert
def img2img(task: Task):
    """Generate images from the task image plus prompt and upload them.

    Returns a dict with ``modified_prompts``, ``generated_image_urls`` and
    ``has_nsfw``.
    """
    prompt, _ = get_patched_prompt(task)

    lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
    lora_patcher.patch()

    try:
        # Seed the global torch RNG with the task seed before generating.
        torch.manual_seed(task.get_seed())

        images, has_nsfw = img2img_pipe.process(
            prompt=prompt,
            imageUrl=task.get_imageUrl(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
            steps=task.get_steps(),
            width=task.get_width(),
            height=task.get_height(),
            strength=task.get_i2i_strength(),
            guidance_scale=task.get_i2i_guidance_scale(),
            **lora_patcher.kwargs(),
        )

        generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
    finally:
        # Undo the style patch even when generation or upload fails, so the
        # shared pipeline is not left patched for the next request.
        lora_patcher.cleanup()

    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
270
+
271
+
272
def model_fn(model_dir):
    """Load every helper and pipeline used by this module.

    ``model_dir`` is forwarded to the LoRA style loader, the ControlNet
    loader and the text-to-image pipeline loader. Returns None.
    """
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)
    FailureHandler.register()
    avatar.load_local()

    # Helpers whose ``load`` takes no arguments, in their original order.
    for helper in (prompt_modifier, img2text, img_classifier):
        helper.load()

    lora_style.load(model_dir)
    safety_checker.load()

    controlnet.load(model_dir)
    text2img_pipe.load(model_dir)
    # The img2img pipeline is created from the loaded text2img pipeline.
    img2img_pipe.create(text2img_pipe)

    for pipeline in (text2img_pipe, img2img_pipe, controlnet):
        safety_checker.apply(pipeline)

    print("Logs: model loaded ....")
298
+
299
+
300
@FailureHandler.clear
def predict_fn(data, pipe):
    """Run the task described by ``data`` and return its result dict.

    Any exception raised while configuring or running the task is printed,
    reported through ``slack.error_alert`` and turned into a ``None`` result.
    ``pipe`` is not used.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Per-request setup: configs, style arguments, styles and avatars.
        set_configs_from_task(task)
        apply_style_args(data)
        lora_style.fetch_styles()
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.TEXT_TO_IMAGE:
            # Prompts asking for a character sheet go through the pose
            # pipeline with picked poses.
            if "character sheet" in task.get_prompt().lower():
                return pose(task, s3_outkey="", poses=pickPoses())
            return text2img(task)

        handlers = {
            TaskType.IMAGE_TO_IMAGE: img2img,
            TaskType.CANNY: canny,
            TaskType.POSE: pose,
            TaskType.TILE_UPSCALE: tile_upscale,
        }
        handler = handlers.get(task_type)
        if handler is None:
            raise Exception("Invalid task type")
        return handler(task)
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
inference2.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import torch
4
+
5
+ from internals.data.dataAccessor import update_db
6
+ from internals.data.task import ModelType, Task, TaskType
7
+ from internals.pipelines.inpainter import InPainter
8
+ from internals.pipelines.object_remove import ObjectRemoval
9
+ from internals.pipelines.prompt_modifier import PromptModifier
10
+ from internals.pipelines.remove_background import RemoveBackground
11
+ from internals.pipelines.safety_checker import SafetyChecker
12
+ from internals.pipelines.upscaler import Upscaler
13
+ from internals.util.avatar import Avatar
14
+ from internals.util.cache import clear_cuda
15
+ from internals.util.commons import (construct_default_s3_url, upload_image,
16
+ upload_images)
17
+ from internals.util.config import set_configs_from_task, set_root_dir
18
+ from internals.util.failure_hander import FailureHandler
19
+ from internals.util.slack import Slack
20
+
21
+ torch.backends.cudnn.benchmark = True
22
+ torch.backends.cuda.matmul.allow_tf32 = True
23
+
24
+ num_return_sequences = 4
25
+ auto_mode = False
26
+
27
+ slack = Slack()
28
+
29
+ prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
30
+ upscaler = Upscaler()
31
+ inpainter = InPainter()
32
+ safety_checker = SafetyChecker()
33
+ object_removal = ObjectRemoval()
34
+ avatar = Avatar()
35
+
36
+
37
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background of the task image and upload the result.

    Returns a dict whose ``generated_image_url`` is built from the upload key.
    """
    output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())

    # A fresh RemoveBackground instance is created for every call.
    output_image = RemoveBackground().remove(task.get_imageUrl())
    upload_image(output_image, output_key)

    result_url = construct_default_s3_url(output_key)
    return {"generated_image_url": result_url}
47
+
48
+
49
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task image and upload the results.

    Returns a dict with ``modified_prompts`` and ``generated_image_urls``.
    """
    prompt = avatar.add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    print({"prompts": prompt})

    try:
        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # Run the CUDA cleanup on failure as well as on success.
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}
74
+
75
+
76
@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    """Remove the masked object from the task image and upload the result.

    Returns a dict with a single ``generated_image_urls`` entry holding the
    value returned by ``upload_image``.
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    try:
        images = object_removal.process(
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            seed=task.get_seed(),
            width=task.get_width(),
            height=task.get_height(),
        )
        generated_image_urls = upload_image(images[0], output_key)
    finally:
        # Run the CUDA cleanup on failure as well as on success.
        clear_cuda()

    # NOTE(review): the key is plural although a single image is uploaded;
    # kept as-is because consumers of this result may depend on it.
    return {"generated_image_urls": generated_image_urls}
93
+
94
+
95
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task image with the anime or the default upscaler.

    The anime upscaler is used when the task's model type is
    ``ModelType.ANIME``; every other model type uses the default one.
    Returns a dict whose ``generated_image_url`` is built from the upload key.
    """
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    if task.get_modelType() == ModelType.ANIME:
        print("Using Anime model")
        run_upscaler = upscaler.upscale_anime
    else:
        print("Using Real model")
        run_upscaler = upscaler.upscale

    out_img = run_upscaler(
        image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension()
    )

    # The upscaler output is wrapped in a BytesIO before upload.
    upload_image(BytesIO(out_img), output_key)
    return {"generated_image_url": construct_default_s3_url(output_key)}
113
+
114
+
115
def model_fn(model_dir):
    """Load every helper and pipeline used by this module.

    ``model_dir`` is forwarded to the object-removal loader only.
    Returns None.
    """
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)
    FailureHandler.register()
    avatar.load_local()

    for helper in (prompt_modifier, safety_checker):
        helper.load()

    object_removal.load(model_dir)
    for pipeline in (upscaler, inpainter):
        pipeline.load()

    # Only the inpainter gets the safety checker applied here.
    safety_checker.apply(inpainter)

    print("Logs: model loaded ....")
135
+
136
+
137
@FailureHandler.clear
def predict_fn(data, pipe):
    """Run the task described by ``data`` and return its result dict.

    Any exception raised while configuring or running the task is printed,
    reported through ``slack.error_alert`` and turned into a ``None`` result.
    ``pipe`` is not used.
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Configure the environment once, inside the ``try`` so that a
        # failure here is reported like any other task error. (The original
        # also called this a second time before the ``try``.)
        set_configs_from_task(task)

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
internals/__init__.py ADDED
File without changes
internals/data/__init__.py ADDED
File without changes
internals/data/dataAccessor.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from typing import Dict, List, Optional
3
+
4
+ import requests
5
+ from pydash import includes
6
+
7
+ from internals.data.task import Task
8
+ from internals.util.config import api_endpoint, api_headers
9
+ from internals.util.slack import Slack
10
+
11
+
12
def updateSource(sourceId, userId, state):
    """PATCH the source record ``sourceId`` with the given ``state``.

    Request failures are printed and swallowed; the function returns None.
    """
    print("update source is called")
    url = api_endpoint() + f"/comic-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }

    try:
        response = requests.patch(
            url, headers=headers, json={"state": state}, timeout=10
        )
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")
    else:
        print("update source response", response)
32
+
33
+
34
def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    """Mark the generated images of source ``sourceId`` as active.

    Sends a PATCH with ``state`` set to ``"ACTIVE"`` and the ``has_nsfw``
    flag. Request failures are printed and swallowed; returns None.
    """
    print("save generation called")
    url = api_endpoint() + "/comic-crecoai/source/" + str(sourceId) + "/generatedImages"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}

    try:
        # Without a timeout the request can block indefinitely and the
        # Timeout handler below can never fire; 10 seconds matches the other
        # requests in this module.
        requests.patch(url, headers=headers, json=data, timeout=10)
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
    return
53
+
54
+
55
def getStyles() -> Optional[Dict]:
    """Fetch the style list from the API.

    Returns the decoded JSON body, or None when the request fails, times out
    or the body is not valid JSON.
    """
    url = api_endpoint() + "/comic-crecoai/style"
    try:
        response = requests.get(
            url,
            timeout=10,
            # SECURITY NOTE(review): this API key is hard-coded in source
            # control. Kept unchanged to preserve behaviour; it should be
            # moved to configuration/secret storage and rotated.
            headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
        )
        return response.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except (requests.exceptions.RequestException, ValueError) as e:
        # ``response.json()`` raises a ValueError subclass on a non-JSON
        # body, which is not always a RequestException.
        print(f"Error while fetching styles: {e}")
    return None
69
+
70
+
71
def getCharacters(model_id: str) -> Optional[List]:
    """Fetch the characters of model ``model_id`` from the API.

    Returns the ``["data"]["characters"]`` entry of the JSON response, or
    None when anything goes wrong (the error is printed).
    """
    url = api_endpoint() + "/comic-crecoai/model/{}".format(model_id)
    try:
        payload = requests.get(url, timeout=10, headers=api_headers()).json()
        return payload["data"]["characters"]
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as e:
        print(f"Error while fetching characters: {e}")
    return None
83
+
84
+
85
def update_db(func):
    """Decorator that tracks a task function's progress in the backend.

    The wrapped function must receive a ``Task`` as its first positional
    argument. The source is marked ``INPROGRESS`` before the call and
    ``COMPLETED`` afterwards, and the generated images are saved with the
    ``has_nsfw`` flag taken from the function's result dict. On any exception
    the error is printed, reported to Slack, the source is marked ``FAILED``
    and the wrapper returns None.
    """
    # Local import so the decorator keeps the wrapped function's metadata
    # without touching the module's import block.
    import functools

    @functools.wraps(func)
    def caller(*args, **kwargs):
        # ``not args`` avoids an IndexError when the task is passed by
        # keyword; isinstance also accepts Task subclasses.
        if not args or not isinstance(args[0], Task):
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            slack = Slack()
            slack.error_alert(task, e)
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
            return None

    return caller
internals/data/result.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from internals.util.config import get_nsfw_access
2
+
3
+
4
class Result:
    """Holder for generated images and their NSFW flag."""

    # Class-level defaults; both are overwritten per instance in __init__.
    images, nsfw = None, None

    def __init__(self, images, nsfw):
        self.images = images
        self.nsfw = nsfw

    @staticmethod
    def from_result(result):
        """Convert a pipeline output into an ``(images, has_nsfw)`` tuple.

        ``result`` must expose ``images`` and ``nsfw_content_detected``. A
        list of per-image flags is collapsed with ``any``. The flag is forced
        to False when ``get_nsfw_access()`` is truthy.
        """
        has_nsfw = result.nsfw_content_detected
        if has_nsfw and isinstance(has_nsfw, list):
            has_nsfw = any(has_nsfw)

        # Logical ``not``, not bitwise ``~``: ``~True`` is -2 and ``~False``
        # is -1, both truthy, so the original never suppressed the flag.
        has_nsfw = not get_nsfw_access() and has_nsfw
        return (result.images, bool(has_nsfw))
internals/data/task.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from typing import Union
3
+
4
+ import numpy as np
5
+
6
+
7
class TaskType(Enum):
    """Task kinds, keyed by the ``task_type`` string of a request payload."""

    # Generation tasks.
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    # Image-editing tasks.
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"
    OBJECT_REMOVAL = "OBJECT_REMOVAL"
17
+
18
+
19
class ModelType(Enum):
    """Model families, keyed by their numeric model id."""

    REAL = 10000
    ANIME = 10001
    COMIC = 10002
23
+
24
+
25
class Task:
    """Accessor wrapper around one request payload dict.

    The payload is stored by reference and normalised in place on
    construction: a missing, None or -1 seed is replaced by a random one, and
    the prompt is replaced by "" when None or truncated to 200 characters
    otherwise.
    """

    def __init__(self, data):
        self.__data = data
        if data.get("seed", -1) is None or self.get_seed() == -1:
            # An explicit int64 dtype keeps the range valid on platforms
            # whose default numpy integer is 32-bit, and ``int()`` stores a
            # plain Python int instead of a numpy scalar in the payload.
            self.__data["seed"] = int(
                np.random.randint(0, np.iinfo(np.int64).max, dtype=np.int64)
            )
        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        else:
            # Prompts are capped at 200 characters.
            self.__data["prompt"] = prompt[:200]

    def get_taskId(self) -> str:
        return self.__data.get("task_id")

    def get_sourceId(self) -> str:
        return self.__data.get("source_id")

    def get_imageUrl(self) -> str:
        return self.__data.get("imageUrl", None)

    def get_prompt(self) -> str:
        return self.__data.get("prompt", "")

    def get_userId(self) -> str:
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        return self.__data.get("email", "")

    def get_style(self) -> str:
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> ModelType:
        """Return the ``ModelType`` for the model id (ValueError if unknown)."""
        model_id = self.get_model_id()
        return ModelType(model_id)

    def get_model_id(self) -> int:
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        return int(self.__data.get("steps", "75"))

    def get_type(self) -> Union[TaskType, None]:
        """Return the ``TaskType`` of the payload, or None when unrecognised."""
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> str:
        return self.__data.get("maskImageUrl")

    def get_negative_prompt(self) -> str:
        # An explicit None in the payload would otherwise leak through and be
        # rendered as the text "None" when callers format it into a prompt.
        return self.__data.get("negative_prompt", "") or ""

    def is_prompt_engineering(self) -> bool:
        return self.__data.get("auto_mode", True)

    def get_queue_name(self) -> str:
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        return self.__data.get("resize_dimension", 1024)

    def get_ti_guidance_scale(self) -> float:
        return self.__data.get("ti_guidance_scale", 7.5)

    def get_i2i_guidance_scale(self) -> float:
        return self.__data.get("i2i_guidance_scale", 7.5)

    def get_i2i_strength(self) -> float:
        return self.__data.get("i2i_strength", 0.75)

    def get_cy_guidance_scale(self) -> float:
        return self.__data.get("cy_guidance_scale", 9)

    def get_po_guidance_scale(self) -> float:
        return self.__data.get("po_guidance_scale", 7.5)

    def get_nsfw_threshold(self) -> float:
        return self.__data.get("nsfw_threshold", 0.03)

    def can_access_nsfw(self) -> bool:
        return self.__data.get("can_access_nsfw", False)

    def get_access_token(self) -> str:
        return self.__data.get("access_token", "")

    def get_raw(self) -> dict:
        """Return a shallow copy of the underlying payload dict."""
        return self.__data.copy()
internals/pipelines/commons.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionImg2ImgPipeline
5
+
6
+ from internals.data.result import Result
7
+ from internals.pipelines.twoStepPipeline import two_step_pipeline
8
+ from internals.util.commons import disable_safety_checker, download_image
9
+
10
+
11
class AbstractPipeline:
    """Base class for pipeline wrappers; both hooks are no-ops here."""

    def load(self, model_dir: str):
        """Hook for loading a pipeline from ``model_dir``; does nothing in the base class."""
        return None

    def create(self, pipe):
        """Hook for building a pipeline from ``pipe``; does nothing in the base class."""
        return None
17
+
18
+
19
class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around the project's ``two_step_pipeline``."""

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in fp16, move it to CUDA and patch it."""
        loaded = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        )
        self.pipe = loaded.to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        """Build the pipeline from the components of ``pipeline.pipe`` instead of from disk."""
        shared_components = pipeline.pipe.components
        self.pipe = two_step_pipeline(**shared_components).to("cuda")
        self.__patch()

    def __patch(self):
        # switch the wrapped pipeline to xformers memory-efficient attention
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: Union[str, List[str]] = None,
        modified_prompts: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        """Forward every argument unchanged to ``pipe.two_step_pipeline``.

        Returns:
            The pipeline output wrapped by ``Result.from_result``.
        """
        call_kwargs = dict(
            prompt=prompt,
            modified_prompts=modified_prompts,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
            iteration=iteration,
        )
        output = self.pipe.two_step_pipeline(**call_kwargs)
        return Result.from_result(output)
78
+
79
+
80
class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around ``StableDiffusionImg2ImgPipeline``."""

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in fp16, move it to CUDA and patch it."""
        loaded = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        )
        self.pipe = loaded.to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        """Build the pipeline from the components of ``pipeline.pipe`` instead of from disk."""
        shared_components = pipeline.pipe.components
        self.pipe = StableDiffusionImg2ImgPipeline(**shared_components).to("cuda")
        self.__patch()

    def __patch(self):
        # switch the wrapped pipeline to xformers memory-efficient attention
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        strength: float,
        guidance_scale: float,
        steps: int,
        width: int,
        height: int,
    ):
        """Download ``imageUrl``, resize it to ``(width, height)`` and run img2img.

        One image is requested per prompt.

        Returns:
            The pipeline output wrapped by ``Result.from_result``.
        """
        init_image = download_image(imageUrl).resize((width, height))

        call_kwargs = dict(
            prompt=prompt,
            image=init_image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        )
        output = self.pipe.__call__(**call_kwargs)
        return Result.from_result(output)
internals/pipelines/controlnets.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from controlnet_aux import OpenposeDetector
7
+ from diffusers import (
8
+ ControlNetModel,
9
+ DiffusionPipeline,
10
+ StableDiffusionControlNetPipeline,
11
+ UniPCMultistepScheduler,
12
+ )
13
+ from PIL import Image
14
+ from tqdm import gui
15
+
16
+ from internals.data.result import Result
17
+ from internals.pipelines.commons import AbstractPipeline
18
+ from internals.util.cache import clear_cuda_and_gc
19
+ from internals.util.commons import download_image
20
+
21
+
22
class ControlNet(AbstractPipeline):
    """ControlNet-conditioned generation for canny, pose and tile-upscale tasks.

    A single ControlNet model is attached at any time. The ``load_*`` methods
    swap it on both wrapped pipelines, and every ``process_*`` method raises
    when the model it needs is not the attached one.
    """

    # tag of the attached ControlNet: "canny", "pose", "tile_upscaler" or ""
    __current_task_name = ""

    def load(self, model_dir: str):
        """Attach the canny ControlNet and build both pipelines from ``model_dir``."""
        # canny is the ControlNet attached by default
        self.load_canny()

        # custom img2img ControlNet pipeline; process_tile_upscaler runs on this one
        img2img_pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            custom_pipeline="stable_diffusion_controlnet_img2img",
        ).to("cuda")
        img2img_pipe.scheduler = UniPCMultistepScheduler.from_config(
            img2img_pipe.scheduler.config
        )
        img2img_pipe.enable_model_cpu_offload()
        img2img_pipe.enable_xformers_memory_efficient_attention()
        self.pipe = img2img_pipe

        # standard ControlNet pipeline built from the same components;
        # process_canny and process_pose run on this one
        controlnet_pipe = StableDiffusionControlNetPipeline(
            **img2img_pipe.components
        ).to("cuda")
        controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(
            controlnet_pipe.scheduler.config
        )
        controlnet_pipe.enable_xformers_memory_efficient_attention()
        self.pipe2 = controlnet_pipe

    def __attach_controlnet(self, task_name: str, repo_id: str):
        # Load ``repo_id`` and attach it everywhere, unless ``task_name`` is already active.
        if self.__current_task_name == task_name:
            return
        model = ControlNetModel.from_pretrained(
            repo_id, torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = task_name
        self.controlnet = model
        # the pipelines do not exist yet on the first call made from load()
        for pipe_attr in ("pipe", "pipe2"):
            if hasattr(self, pipe_attr):
                getattr(self, pipe_attr).controlnet = model
        clear_cuda_and_gc()

    def load_canny(self):
        """Attach the canny ControlNet (no-op when it is already attached)."""
        self.__attach_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Attach the openpose ControlNet (no-op when it is already attached)."""
        self.__attach_controlnet("pose", "lllyasviel/control_v11p_sd15_openpose")

    def load_tile_upscaler(self):
        """Attach the tile ControlNet (no-op when it is already attached)."""
        self.__attach_controlnet(
            "tile_upscaler", "lllyasviel/control_v11f1e_sd15_tile"
        )

    def cleanup(self):
        """Detach the ControlNet from both pipelines and reset the active tag."""
        self.pipe.controlnet = None
        self.pipe2.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        guidance_scale: float,
        height: int,
        width: int,
    ):
        """Generate images conditioned on the canny edges of the image at ``imageUrl``.

        Raises:
            Exception: when the canny ControlNet is not the attached model.
        """
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        resized = download_image(imageUrl).resize((width, height))
        edge_image = self.__canny_detect_edge(resized)

        output = self.pipe2.__call__(
            prompt=prompt,
            image=edge_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return Result.from_result(output)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        guidance_scale: float,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on the given pose image(s).

        Raises:
            Exception: when the pose ControlNet is not the attached model.
        """
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        output = self.pipe2.__call__(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return Result.from_result(output)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float,
    ):
        """Upscale the image at ``imageUrl`` with the tile ControlNet.

        The download is resized to ``(width, height)`` and then rescaled by
        ``__resize_for_condition_image``; the rescaled image is used both as
        the init image and as the conditioning image.

        Raises:
            Exception: when the tile ControlNet is not the attached model.
        """
        if self.__current_task_name != "tile_upscaler":
            raise Exception("ControlNet is not loaded with tile_upscaler model")

        torch.manual_seed(seed)

        resized = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(resized, resize_dimension)
        cond_width, cond_height = condition_image.size

        output = self.pipe.__call__(
            image=condition_image,
            prompt=prompt,
            controlnet_conditioning_image=condition_image,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=cond_height,
            width=cond_width,
            strength=1.0,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(output)

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Run the OpenPose detector (hands and face enabled) on the image at ``imageUrl``."""
        # the detector is instantiated on every call
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        source_image = download_image(imageUrl)
        return detector.__call__(source_image, hand_and_face=True)

    def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
        # Canny edge map of ``image``, replicated across three channels.
        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        three_channel = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(three_channel)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        # Scale so the shorter side is about ``resolution``, then snap both
        # sides to the nearest multiple of 64.
        rgb_image = image.convert("RGB")
        src_width, src_height = rgb_image.size
        scale = float(resolution) / min(src_width, src_height)
        target_height = int(round(src_height * scale / 64.0)) * 64
        target_width = int(round(src_width * scale / 64.0)) * 64
        return rgb_image.resize((target_width, target_height), resample=Image.LANCZOS)
internals/pipelines/img_classifier.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from transformers import pipeline
4
+
5
+ from internals.util.commons import download_image
6
+
7
+
8
class ImageClassifier:
    """Zero-shot image style classifier built on a CLIP ``transformers`` pipeline."""

    # labels used when the caller does not supply any
    _DEFAULT_CANDIDATES = ("realistic", "anime", "comic")

    def __init__(self, candidates: List[str] = None):
        """Store the candidate labels to classify against.

        Args:
            candidates: Labels for zero-shot classification. Defaults to
                ``["realistic", "anime", "comic"]``.
        """
        # A list literal as the default would be one object shared by every
        # instance; build a fresh list per instance instead.
        if candidates is None:
            candidates = self._DEFAULT_CANDIDATES
        self.__candidates = list(candidates)

    def load(self):
        """Instantiate the zero-shot image classification pipeline."""
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )

    def classify(self, image_url: str, width: int, height: int) -> str:
        """Return the first label the pipeline reports for the image at ``image_url``.

        The image is downloaded and resized to ``(width, height)`` first.
        Returns an empty string when the pipeline reports no labels.
        """
        image = download_image(image_url).resize((width, height))
        per_image_results = self.pipe.__call__(
            [image], candidate_labels=self.__candidates
        )
        # a single image was submitted, so only the first result list matters
        labels = per_image_results[0]
        return labels[0]["label"] if labels else ""
internals/pipelines/img_to_text.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import torch
4
+ from torchvision import transforms
5
+ from transformers import BlipForConditionalGeneration, BlipProcessor
6
+
7
+ from internals.util.commons import download_image
8
+
9
+
10
class Image2Text:
    """Image captioning with the BLIP large captioning model."""

    def load(self):
        """Load the BLIP processor and the fp16 captioning model onto CUDA."""
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
        ).to("cuda")

    def process(self, imageUrl: str) -> str:
        """Return a caption for the image at ``imageUrl``.

        The image is resized to 512x512 before captioning, the raw decoded
        output is printed, and ``</.>``-shaped tags, newlines and ``[SEP]``
        markers are stripped from the first decoded sequence.
        """
        image = download_image(imageUrl).resize((512, 512))
        inputs = self.processor.__call__(image, return_tensors="pt").to(
            "cuda", torch.float16
        )
        output_ids = self.model.generate(
            **inputs, do_sample=False, top_p=0.9, max_length=128
        )
        decoded = self.processor.batch_decode(output_ids)
        print(decoded)
        caption = decoded[0]
        # Raw string: same pattern as before, without the invalid "\[" escape
        # that a plain string literal triggers a warning for.
        return re.sub(r"</.>|\n|\[SEP\]", "", caption)
internals/pipelines/inpainter.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline
5
+
6
+ from internals.pipelines.commons import AbstractPipeline
7
+ from internals.util.commons import disable_safety_checker, download_image
8
+
9
+
10
class InPainter(AbstractPipeline):
    """Inpainting wrapper around ``StableDiffusionInpaintPipeline``."""

    def load(self):
        """Load the fixed inpainting checkpoint in fp16 on CUDA and disable its safety checker."""
        inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "jayparmr/icbinp_v8_inpaint_v2",
            torch_dtype=torch.float16,
        )
        self.pipe = inpaint_pipe.to("cuda")
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
    ):
        """Inpaint the masked region of the image at ``image_url``.

        Both the image and the mask are downloaded and resized to
        ``(width, height)``.

        Returns:
            The ``images`` attribute of the pipeline output.
        """
        torch.manual_seed(seed)

        target_size = (width, height)
        source_image = download_image(image_url).resize(target_size)
        mask_image = download_image(mask_image_url).resize(target_size)

        output = self.pipe.__call__(
            prompt=prompt,
            image=source_image,
            mask_image=mask_image,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
        )
        return output.images
internals/pipelines/object_remove.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from omegaconf import OmegaConf
10
+ from PIL import Image
11
+ from torch.utils.data._utils.collate import default_collate
12
+
13
+ from internals.util.commons import download_file, download_image
14
+ from internals.util.config import get_root_dir
15
+ from saicinpainting.evaluation.utils import move_to_device
16
+ from saicinpainting.training.data.datasets import make_default_val_dataset
17
+ from saicinpainting.training.trainers import load_checkpoint
18
+
19
+
20
class ObjectRemoval:
    """Object removal (inpainting) with the LaMa model."""

    def load(self, model_dir):
        """Download the LaMa checkpoint and load it onto CUDA in predict-only mode.

        Args:
            model_dir: Accepted for interface parity; not used by this method.
        """
        print("Downloading LAMA model...")

        self.lama_path = Path.home() / ".cache" / "lama"

        checkpoint_file = self.lama_path / "models" / "best.ckpt"
        checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt",
            checkpoint_file,
        )
        config = OmegaConf.load(get_root_dir() + "/config.yml")
        config.training_model.predict_only = True
        self.model = load_checkpoint(
            config, str(checkpoint_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Remove the masked object from the image at ``image_url``.

        The image and mask are resized to ``(width, height)`` and saved into
        the LaMa input folder, from which the validation dataset is built.

        Returns:
            A list of RGB ``PIL.Image`` results, one per dataset item.
        """
        torch.manual_seed(seed)

        img_folder = self.lama_path / "images"
        indir = img_folder / "input"
        # parents=True also creates img_folder
        indir.mkdir(parents=True, exist_ok=True)

        target_size = (width, height)
        download_image(image_url).resize(target_size).save(indir / "data.png")
        download_image(mask_image_url).resize(target_size).save(
            indir / "data_mask.png"
        )

        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # binarize the mask
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)

            # CHW tensor -> HWC uint8 array
            inpainted = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy()
            inpainted = np.clip(inpainted * 255, 0, 255).astype("uint8")

            # Convert in memory. The previous version wrote a shared "out.png",
            # re-read it and deleted it, which left the file behind on errors
            # and was unsafe for overlapping calls.
            out_images.append(Image.fromarray(inpainted).convert("RGB"))

        return out_images
internals/pipelines/prompt_modifier.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
4
+
5
+
6
class PromptModifier:
    """Expand a prompt into several variants with the MagicPrompt causal LM."""

    def __init__(self, num_of_sequences: Optional[int] = 4):
        """Configure how many prompt variants ``modify`` returns.

        Args:
            num_of_sequences: Number of sequences requested from generation.
        """
        # substrings replaced in generated prompts, mapped to their replacement
        self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
        self.__num_of_sequences = num_of_sequences

    def load(self):
        """Load the MagicPrompt model and tokenizer; pad on the left with the EOS token."""
        self.prompter_model = AutoModelForCausalLM.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        self.prompter_tokenizer.padding_side = "left"

    def modify(self, text: str) -> List[str]:
        """Return generated prompt variants for ``text`` with blacklisted terms replaced."""
        eos_id = self.prompter_tokenizer.eos_token_id

        num_sequences = self.__num_of_sequences
        # Beam search cannot return more sequences than it has beams, so grow
        # the beam count with the request; 4 stays the minimum as before.
        num_beams = max(4, num_sequences) if num_sequences else 4

        generation_config = GenerationConfig(
            do_sample=False,
            max_new_tokens=75,
            num_beams=num_beams,
            num_return_sequences=num_sequences,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            length_penalty=-1.0,
        )

        input_ids = self.prompter_tokenizer(text.strip(), return_tensors="pt").input_ids
        outputs = self.prompter_model.generate(
            input_ids, generation_config=generation_config
        )
        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        return self.__patch_blacklist_words(output_texts)

    def __patch_blacklist_words(self, texts: List[str]):
        # Apply every blacklist replacement to every text.
        def replace_all(text, replacements):
            for banned, substitute in replacements.items():
                text = text.replace(banned, substitute)
            return text

        return [replace_all(text, self.__blacklist) for text in texts]
internals/pipelines/remove_background.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from typing import Union
3
+
4
+ from PIL import Image
5
+ from rembg import remove
6
+
7
+ from internals.util.commons import read_url
8
+
9
+
10
class RemoveBackground:
    """Background removal via ``rembg``."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` with its background removed.

        Args:
            image: Either an already opened PIL image or a URL string, which
                is fetched with ``read_url`` and opened first.
        """
        # isinstance (rather than an exact type check) also accepts str subclasses
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        # inside the method body, ``remove`` resolves to the module-level rembg function
        return remove(image)
internals/pipelines/safety_checker.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from re import L
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
8
+
9
+ from internals.pipelines.commons import AbstractPipeline
10
+ from internals.util.config import get_nsfw_access, get_nsfw_threshold
11
+
12
+
13
def cosine_distance(image_embeds, text_embeds):
    """Return the matrix product of the normalized inputs.

    Both inputs are passed through ``nn.functional.normalize`` and the first
    is multiplied by the transpose of the second, so entry ``[i, j]`` is the
    cosine similarity between row ``i`` of ``image_embeds`` and row ``j`` of
    ``text_embeds``.
    """
    image_unit = nn.functional.normalize(image_embeds)
    text_unit = nn.functional.normalize(text_embeds)
    return image_unit.mm(text_unit.t())
17
+
18
+
19
class SafetyChecker:
    """Loads the custom safety checker and attaches it to pipeline wrappers."""

    def load(self):
        """Load ``StableDiffusionSafetyCheckerV2`` in fp16 onto CUDA."""
        checker = StableDiffusionSafetyCheckerV2.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
        )
        self.model = checker.to("cuda")

    def apply(self, pipeline: AbstractPipeline):
        """Set the loaded checker as ``safety_checker`` on ``pipe`` and ``pipe2`` where present."""
        for pipe_attr in ("pipe", "pipe2"):
            if hasattr(pipeline, pipe_attr):
                getattr(pipeline, pipe_attr).safety_checker = self.model
30
+
31
+
32
class StableDiffusionSafetyCheckerV2(PreTrainedModel):
    """CLIP-based safety checker that blurs flagged images.

    The NSFW threshold and the blur exemption are read from
    ``get_nsfw_threshold`` and ``get_nsfw_access``.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(
            config.vision_config.hidden_size, config.projection_dim, bias=False
        )

        # 17 concept embeddings and 3 special-care embeddings, each with a
        # per-concept weight that forward() uses as a threshold
        self.concept_embeds = nn.Parameter(
            torch.ones(17, config.projection_dim), requires_grad=False
        )
        self.special_care_embeds = nn.Parameter(
            torch.ones(3, config.projection_dim), requires_grad=False
        )

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(
            torch.ones(3), requires_grad=False
        )

    @torch.no_grad()
    def forward(self, clip_input, images):
        """Score ``clip_input`` against the concept embeddings and blur flagged images.

        Args:
            clip_input: Input passed to the CLIP vision model.
            images: Indexable batch of images (tensors or arrays); flagged
                entries are replaced in place with a blurred version unless
                ``get_nsfw_access()`` is truthy.

        Returns:
            ``(images, has_nsfw_concepts)`` where the second item is a list
            with one bool per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = (
            cosine_distance(image_embeds, self.special_care_embeds)
            .cpu()
            .float()
            .numpy()
        )
        cos_dist = (
            cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
        )

        # read the configured threshold once instead of once per concept
        nsfw_threshold = get_nsfw_threshold()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {
                "special_scores": {},
                "special_care": [],
                "concept_scores": {},
                "bad_concepts": [],
            }

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx, concept_cos in enumerate(special_cos_dist[i]):
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["special_scores"][concept_idx] = score
                if score > 0:
                    # tuple rather than a set literal, so index and score stay paired
                    result_img["special_care"].append((concept_idx, score))
                    # a special-care hit tightens every later score for this image
                    adjustment = 0.01

            for concept_idx, concept_cos in enumerate(cos_dist[i]):
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                score = round(concept_cos - concept_threshold + adjustment, 3)
                result_img["concept_scores"][concept_idx] = score
                if score > nsfw_threshold:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
        any_flagged = any(has_nsfw_concepts)

        # Blur images based on NSFW score
        # -------------------------------
        if any_flagged and not get_nsfw_access():
            for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
                # Blur only the images that were flagged themselves; the batch-wide
                # check used before blurred every image once any one was flagged.
                if not has_nsfw_concept:
                    continue
                if torch.is_tensor(images) or torch.is_tensor(images[0]):
                    image = images[idx].cpu().numpy().astype(np.float32)
                    image = cv2.blur(image, (30, 30))
                    images[idx] = torch.from_numpy(image)
                else:
                    images[idx] = cv2.blur(images[idx], (30, 30))

        if any_flagged:
            print("NSFW")

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
        """Tensor-only variant of ``forward``.

        Returns:
            ``(images, has_nsfw_concepts)`` where the second item is a bool
            tensor with one entry per image.
        """
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = (
            special_cos_dist - self.special_care_embeds_weights + adjustment
        )
        # special_scores = special_scores.round(decimals=3)
        special_care = torch.any(special_scores > 0, dim=1)
        # images with a special-care hit get 0.01 added to all concept scores
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(
            -1, cos_dist.shape[1]
        )

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        # concept_scores = concept_scores.round(decimals=3)
        has_nsfw_concepts = torch.any(concept_scores > get_nsfw_threshold(), dim=1)

        # Blur images based on NSFW score
        # -------------------------------
        if not get_nsfw_access():
            # NOTE(review): cv2.blur receives the whole masked sub-batch here, not
            # one image at a time - confirm this works for the shapes in use.
            image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
            image = cv2.blur(image, (30, 30))
            image = torch.from_numpy(image)
            images[has_nsfw_concepts] = image

        return images, has_nsfw_concepts
internals/pipelines/twoStepPipeline.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline
3
+
4
+ torch.backends.cudnn.benchmark = True
5
+ torch.backends.cuda.matmul.allow_tf32 = True
6
+
7
+ from typing import Any, Callable, Dict, List, Optional, Union
8
+
9
+ from diffusers import StableDiffusionPipeline
10
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
11
+
12
+
13
+ class two_step_pipeline(StableDiffusionPipeline):
14
+ @torch.no_grad()
15
+ def two_step_pipeline(
16
+ self,
17
+ prompt: Union[str, List[str]] = None,
18
+ modified_prompts: Union[str, List[str]] = None,
19
+ height: Optional[int] = None,
20
+ width: Optional[int] = None,
21
+ num_inference_steps: int = 50,
22
+ guidance_scale: float = 7.5,
23
+ negative_prompt: Optional[Union[str, List[str]]] = None,
24
+ num_images_per_prompt: Optional[int] = 1,
25
+ eta: float = 0.0,
26
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
27
+ latents: Optional[torch.FloatTensor] = None,
28
+ prompt_embeds: Optional[torch.FloatTensor] = None,
29
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
30
+ output_type: Optional[str] = "pil",
31
+ return_dict: bool = True,
32
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
33
+ callback_steps: int = 1,
34
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
35
+ iteration: float = 3.0,
36
+ ):
37
+ r"""
38
+ Function invoked when calling the pipeline for generation.
39
+ Args:
40
+ prompt (`str` or `List[str]`, *optional*):
41
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
42
+ instead.
43
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
44
+ The height in pixels of the generated image.
45
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
46
+ The width in pixels of the generated image.
47
+ num_inference_steps (`int`, *optional*, defaults to 50):
48
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
49
+ expense of slower inference.
50
+ guidance_scale (`float`, *optional*, defaults to 7.5):
51
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
52
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
53
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
54
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
55
+ usually at the expense of lower image quality.
56
+ negative_prompt (`str` or `List[str]`, *optional*):
57
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
58
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
59
+ less than `1`).
60
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
61
+ The number of images to generate per prompt.
62
+ eta (`float`, *optional*, defaults to 0.0):
63
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
64
+ [`schedulers.DDIMScheduler`], will be ignored for others.
65
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
66
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
67
+ to make generation deterministic.
68
+ latents (`torch.FloatTensor`, *optional*):
69
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
70
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
71
+ tensor will ge generated by sampling using the supplied random `generator`.
72
+ prompt_embeds (`torch.FloatTensor`, *optional*):
73
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
74
+ provided, text embeddings will be generated from `prompt` input argument.
75
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
76
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
77
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
78
+ argument.
79
+ output_type (`str`, *optional*, defaults to `"pil"`):
80
+ The output format of the generate image. Choose between
81
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
82
+ return_dict (`bool`, *optional*, defaults to `True`):
83
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
84
+ plain tuple.
85
+ callback (`Callable`, *optional*):
86
+ A function that will be called every `callback_steps` steps during inference. The function will be
87
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
88
+ callback_steps (`int`, *optional*, defaults to 1):
89
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
90
+ called at every step.
91
+ cross_attention_kwargs (`dict`, *optional*):
92
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
93
+ `self.processor` in
94
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
95
+ Examples:
96
+ Returns:
97
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
98
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
99
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
100
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
101
+ (nsfw) content, according to the `safety_checker`.
102
+ """
103
+ # 0. Default height and width to unet
104
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
105
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
106
+
107
+ # 1. Check inputs. Raise error if not correct
108
+ self.check_inputs(
109
+ prompt,
110
+ height,
111
+ width,
112
+ callback_steps,
113
+ negative_prompt,
114
+ prompt_embeds,
115
+ negative_prompt_embeds,
116
+ )
117
+
118
+ # 2. Define call parameters
119
+ if prompt is not None and isinstance(prompt, str):
120
+ batch_size = 1
121
+ elif prompt is not None and isinstance(prompt, list):
122
+ batch_size = len(prompt)
123
+ else:
124
+ batch_size = prompt_embeds.shape[0]
125
+
126
+ device = self._execution_device
127
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
128
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
129
+ # corresponds to doing no classifier free guidance.
130
+ do_classifier_free_guidance = guidance_scale > 1.0
131
+
132
+ # 3. Encode input prompt
133
+ modified_embeds = self._encode_prompt(
134
+ modified_prompts,
135
+ device,
136
+ num_images_per_prompt,
137
+ do_classifier_free_guidance,
138
+ negative_prompt,
139
+ prompt_embeds=prompt_embeds,
140
+ negative_prompt_embeds=negative_prompt_embeds,
141
+ )
142
+ print("mod prompt size : ", modified_embeds.size(), modified_embeds.dtype)
143
+
144
+ prompt_embeds = self._encode_prompt(
145
+ prompt,
146
+ device,
147
+ num_images_per_prompt,
148
+ do_classifier_free_guidance,
149
+ negative_prompt,
150
+ prompt_embeds=prompt_embeds,
151
+ negative_prompt_embeds=negative_prompt_embeds,
152
+ )
153
+
154
+ print("prompt size : ", prompt_embeds.size(), prompt_embeds.dtype)
155
+
156
+ # 4. Prepare timesteps
157
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
158
+ timesteps = self.scheduler.timesteps
159
+
160
+ # 5. Prepare latent variables
161
+ num_channels_latents = self.unet.config.in_channels
162
+ latents = self.prepare_latents(
163
+ batch_size * num_images_per_prompt,
164
+ num_channels_latents,
165
+ height,
166
+ width,
167
+ prompt_embeds.dtype,
168
+ device,
169
+ generator,
170
+ latents,
171
+ )
172
+
173
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
174
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
175
+
176
+ # 7. Denoising loop
177
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
178
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
179
+ for i, t in enumerate(timesteps):
180
+ # expand the latents if we are doing classifier free guidance
181
+ latent_model_input = (
182
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
183
+ )
184
+ latent_model_input = self.scheduler.scale_model_input(
185
+ latent_model_input, t
186
+ )
187
+
188
+ # predict the noise residual
189
+ noise_pred = self.unet(
190
+ latent_model_input,
191
+ t,
192
+ encoder_hidden_states=prompt_embeds,
193
+ cross_attention_kwargs=cross_attention_kwargs,
194
+ ).sample
195
+
196
+ # perform guidance
197
+ if do_classifier_free_guidance:
198
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
199
+ noise_pred = noise_pred_uncond + guidance_scale * (
200
+ noise_pred_text - noise_pred_uncond
201
+ )
202
+
203
+ # compute the previous noisy sample x_t -> x_t-1
204
+ latents = self.scheduler.step(
205
+ noise_pred, t, latents, **extra_step_kwargs
206
+ ).prev_sample
207
+
208
+ # call the callback, if provided
209
+ if i == len(timesteps) - 1 or (
210
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
211
+ ):
212
+ progress_bar.update()
213
+ if callback is not None and i % callback_steps == 0:
214
+ callback(i, t, latents)
215
+
216
+ if i == int(len(timesteps) / iteration):
217
+ print("modified prompts")
218
+ prompt_embeds = modified_embeds
219
+
220
+ if output_type == "latent":
221
+ image = latents
222
+ has_nsfw_concept = None
223
+ elif output_type == "pil":
224
+ # 8. Post-processing
225
+ image = self.decode_latents(latents)
226
+
227
+ # 9. Run safety checker
228
+ image, has_nsfw_concept = self.run_safety_checker(
229
+ image, device, prompt_embeds.dtype
230
+ )
231
+
232
+ # 10. Convert to PIL
233
+ image = self.numpy_to_pil(image)
234
+ else:
235
+ # 8. Post-processing
236
+ image = self.decode_latents(latents)
237
+
238
+ # 9. Run safety checker
239
+ image, has_nsfw_concept = self.run_safety_checker(
240
+ image, device, prompt_embeds.dtype
241
+ )
242
+
243
+ # Offload last model to CPU
244
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
245
+ self.final_offload_hook.offload()
246
+
247
+ if not return_dict:
248
+ return (image, has_nsfw_concept)
249
+
250
+ return StableDiffusionPipelineOutput(
251
+ images=image, nsfw_content_detected=has_nsfw_concept
252
+ )
internals/pipelines/upscaler.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from basicsr.archs.rrdbnet_arch import RRDBNet
9
+ from basicsr.utils.download_util import load_file_from_url
10
+ from PIL import Image
11
+ from realesrgan import RealESRGANer
12
+
13
+ import internals.util.image as ImageUtil
14
+ from internals.util.commons import download_image
15
+
16
+
17
class Upscaler:
    """Upscale images with Real-ESRGAN (general and anime checkpoints).

    Call `load()` once so both checkpoints are present on disk, then use
    `upscale()` or `upscale_anime()`; both return PNG-encoded bytes.
    """

    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"

    def load(self):
        """Ensure both checkpoints exist under ~/.cache/realesrgan and remember their paths."""
        download_dir = Path.home() / ".cache" / "realesrgan"
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )

    def upscale(self, image: Union[str, Image.Image], resize_dimension: int) -> bytes:
        """Upscale `image` (URL or PIL image) with the general x4plus checkpoint."""
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image, resize_dimension, self.__model_path, model
        )

    def upscale_anime(
        self, image: Union[str, Image.Image], resize_dimension: int
    ) -> bytes:
        """Upscale `image` (URL or PIL image) with the anime_6B checkpoint."""
        # The "6B" checkpoint is a 6-block RRDBNet. Building the 23-block
        # network used for x4plus makes the state-dict load fail.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image, resize_dimension, self.__model_path_anime, model
        )

    def __preload_model(self, url: str, download_dir: Path):
        """Return the local path of the checkpoint at `url`, downloading it if absent."""
        target = download_dir / url.split("/")[-1]
        if target.exists():
            return str(target)
        return load_file_from_url(
            url=url,
            model_dir=str(download_dir),
            progress=True,
            file_name=None,
        )

    def __internal_upscale(
        self,
        image,
        resize_dimension: int,
        model_path: str,
        rrbdnet: RRDBNet,
    ) -> bytes:
        """Run Real-ESRGAN on `image` and return the result as PNG bytes.

        The input is first shrunk so its longer side is 512 px, then enlarged
        by an integer factor of at least 2 chosen from `resize_dimension`.
        """
        if isinstance(image, str):
            image = download_image(image)
        image = ImageUtil.resize_image_to512(image)
        image = ImageUtil.to_bytes(image)

        # NOTE(review): `half` and `gpu_id` are passed as strings exactly as
        # before; confirm against the RealESRGANer signature.
        upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=rrbdnet, half="fp16", gpu_id="0"
        )
        image_array = np.frombuffer(image, dtype=np.uint8)
        input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        # Scale is derived from the shorter side of the decoded image.
        dimension = min(input_image.shape[0], input_image.shape[1])
        scale = max(math.floor(resize_dimension / dimension), 2)
        output, _ = upsampler.enhance(input_image, outscale=scale)
        out_bytes = cv2.imencode(".png", output)[1].tobytes()
        return out_bytes
internals/util/__init__.py ADDED
File without changes
internals/util/args.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Dict
3
+
4
+
5
def apply_style_args(data: Dict):
    """Move a leading ``[style:NAME]`` tag from ``data["prompt"]`` into ``data["style"]``.

    Nothing happens when there is no prompt or the prompt does not start with
    a style tag. `data` is modified in place.
    """
    prompt = data.get("prompt", None)
    if prompt is None:
        return

    # re.match anchors at the start, so only a leading tag is recognised.
    found = re.match(r"\[style:(.*?)\]", prompt)
    if found is None:
        return

    style = found.group(1)
    tag = f"[style:{style}]"
    data["style"] = style
    data["prompt"] = prompt.replace(tag, "").strip()
internals/util/avatar.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+
5
+ from internals.data.dataAccessor import getCharacters
6
+ from internals.util.config import root_dir
7
+
8
+
9
class Avatar:
    """Registry of named characters whose names are expanded inside prompts."""

    # Class-level dict: every Avatar instance shares the same registry.
    __avatars = {}

    def load_local(self):
        """Load characters from ``characters.json`` in the application root."""
        # Read the root dir at call time. The module-level `root_dir` import is
        # bound when this module loads and misses later set_root_dir() calls.
        from internals.util.config import get_root_dir

        self.__find_available_characters(get_root_dir())
        if self.__avatars:
            print("Local characters", self.__avatars)

    def fetch_from_network(self, model_id: int):
        """Fetch the characters for `model_id` and add them to the registry."""
        characters = getCharacters(str(model_id))
        if characters is not None:
            for character in characters:
                item = {
                    "avatarName": str(character["title"]).lower(),
                    "codename": character["tag"],
                    "extraPrompt": character["extraData"]["extraPrompt"],
                }
                self.__avatars[item["avatarName"]] = item

    def add_code_names(self, prompt):
        """Replace each whole-word avatar name in `prompt` with its extra prompt.

        Matching is case-insensitive. One space is appended per registered
        avatar, matching the existing output format.
        """
        for avatar in self.__avatars.values():
            # Escape the name so characters such as "." or "+" match literally.
            pattern = r"\b" + re.escape(avatar["avatarName"]) + r"\b"
            extra = avatar["extraPrompt"]
            # A callable replacement inserts the text verbatim, so backslashes
            # in the extra prompt are not read as regex group references.
            prompt = (
                re.sub(pattern, lambda _m, extra=extra: extra, prompt, flags=re.IGNORECASE)
                + " "
            )
        print(prompt)
        return prompt

    def __find_available_characters(self, path: str):
        """Register the avatars listed in ``<path>/characters.json``, if the file exists."""
        if os.path.exists(path + "/characters.json"):
            print(path)
            try:
                print("Loading characters")
                with open(path + "/characters.json") as f:
                    data = json.load(f)
                    print("Characters: ", data)
                    # Only the first entry is checked for the expected key.
                    if "avatarName" in data[0]:
                        for item in data:
                            self.__avatars[item["avatarName"]] = item
                        print("Avatars", self.__avatars)
                    else:
                        print("Invalid characters.json file")
            except Exception as e:
                # Best effort: a broken local file must not stop start-up.
                print("Error Loading characters", e)
internals/util/cache.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import torch
4
+
5
+
6
def clear_cuda_and_gc():
    """Release cached CUDA memory, then run the Python garbage collector."""
    for release in (clear_cuda, clear_gc):
        release()
9
+
10
+
11
def clear_cuda():
    """Ask torch to empty its CUDA cache."""
    torch.cuda.empty_cache()
    return None
13
+
14
+
15
def clear_gc():
    """Run a garbage-collection pass; the collector's return value is discarded."""
    gc.collect()
    return None
17
+
18
+
19
def auto_clear_cuda_and_gc(controlnet):
    """Build a decorator that frees resources when the wrapped call fails.

    If the wrapped function raises, `controlnet.cleanup()` and
    `clear_cuda_and_gc()` run and the exception is re-raised unchanged.
    Successful calls are passed through without any cleanup.
    """
    import functools

    def auto_clear_cuda_and_gc_wrapper(func):
        # wraps() keeps the decorated function's name and docstring.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                controlnet.cleanup()
                clear_cuda_and_gc()
                raise

        return wrapper

    return auto_clear_cuda_and_gc_wrapper
internals/util/commons.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pprint
4
+ import random
5
+ import re
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from typing import Union
9
+
10
+ import boto3
11
+ import requests
12
+
13
+ from internals.util.config import api_endpoint, api_headers
14
+
15
# Module-level S3 client, created when this module is imported.
# NOTE(review): in the visible code it is referenced only from commented-out
# upload calls; confirm whether it is still needed.
s3 = boto3.client("s3")
16
+ import io
17
+ import urllib.request
18
+
19
+ from PIL import Image
20
+
21
# Phrases mapped to an empty string.
# NOTE(review): presumably artist names to strip from prompts; confirm at the call site.
black_list = {"alphonse mucha": "", "adolphe bouguereau": ""}
pp = pprint.PrettyPrinter(indent=4)

# NOTE(review): these Slack webhook URLs are credentials committed to source.
# They should be rotated and loaded from configuration or the environment.
webhook_url = (
    "https://hooks.slack.com/services/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"
)
error_webhook = (
    "https://hooks.slack.com/services/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"
)

# Keys of the character-sheet images that pickPoses() samples from. They are
# joined onto the comic-assets bucket URL. Some numbers are absent from the sequence.
characterSheets = [
    "character+sheets/1.1.png",
    "character+sheets/10.1.png",
    "character+sheets/11.1.png",
    "character+sheets/12.1.png",
    "character+sheets/13.1.png",
    "character+sheets/14.1.png",
    "character+sheets/16.1.png",
    "character+sheets/17.1.png",
    "character+sheets/18.1.png",
    "character+sheets/19.1.png",
    "character+sheets/2.1.png",
    "character+sheets/20.1.png",
    "character+sheets/21.1.png",
    "character+sheets/22.1.png",
    "character+sheets/23.1.png",
    "character+sheets/24.1.png",
    "character+sheets/25.1.png",
    "character+sheets/26.1.png",
    "character+sheets/27.1.png",
    "character+sheets/28.1.png",
    "character+sheets/29.1.png",
    "character+sheets/3.1.png",
    "character+sheets/30.1.png",
    "character+sheets/31.1.png",
    "character+sheets/32.1.png",
    "character+sheets/33.1.png",
    "character+sheets/34.1.png",
    "character+sheets/35.1.png",
    "character+sheets/36.1.png",
    "character+sheets/38.1.png",
    "character+sheets/39.1.png",
    "character+sheets/4.1.png",
    "character+sheets/40.1.png",
    "character+sheets/42.1.png",
    "character+sheets/43.1.png",
    "character+sheets/44.1.png",
    "character+sheets/45.1.png",
    "character+sheets/46.1.png",
    "character+sheets/47.1.png",
    "character+sheets/48.1.png",
    "character+sheets/49.1.png",
    "character+sheets/5.1.png",
    "character+sheets/50.1.png",
    "character+sheets/51.1.png",
    "character+sheets/52.1.png",
    "character+sheets/53.1.png",
    "character+sheets/54.1.png",
    "character+sheets/55.1.png",
    "character+sheets/56.1.png",
    "character+sheets/57.1.png",
    "character+sheets/58.1.png",
    "character+sheets/59.1.png",
    "character+sheets/60.1.png",
    "character+sheets/61.1.png",
    "character+sheets/62.1.png",
    "character+sheets/63.1.png",
    "character+sheets/64.1.png",
    "character+sheets/65.1.png",
    "character+sheets/66.1.png",
    "character+sheets/7.1.png",
    "character+sheets/8.1.png",
    "character+sheets/9.1.png",
]
95
+
96
+
97
def upload_images(images, processName: str, taskId: str):
    """Upload each image through the content API and return the resulting URLs.

    Files are named ``<taskId><processName>_<index>.png``. The returned URLs
    point at the ``crecoAI/`` prefix of the comic-assets bucket.
    """
    imageUrls = []
    for index, image in enumerate(images):
        file_name = "{}{}_{}.png".format(taskId, processName, index)

        # NOTE(review): the payload is JPEG-encoded although the file name and
        # MIME type say PNG. Kept as-is.
        payload = BytesIO()
        image.save(payload, "JPEG", quality=100)
        payload.seek(0)

        requests.post(
            api_endpoint()
            + "/comic-content/v1.0/upload/crecoai-assets-2?fileName="
            + file_name,
            headers=api_headers(),
            files={"file": ("image.png", payload, "image/png")},
        )
        imageUrls.append(
            "https://comic-assets.s3.ap-south-1.amazonaws.com/crecoAI/" + file_name
        )

    print({"promptImages": imageUrls})

    return imageUrls
124
+
125
+
126
def upload_image(image: Union[Image.Image, BytesIO], out_path):
    """Upload one image through the content API and return its public URL.

    Args:
        image: A PIL image (encoded here as PNG) or a buffer that already
            holds the encoded bytes. The buffer is closed after the upload.
        out_path: Bucket key; a ``crecoAI/`` prefix is dropped from the
            uploaded file name but kept in the returned URL.
    """
    # isinstance also accepts PIL subclasses such as PngImageFile, which an
    # exact type comparison would send down the raw-buffer path.
    if isinstance(image, Image.Image):
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image = buffer

    image.seek(0)
    requests.post(
        api_endpoint()
        + "/comic-content/v1.0/upload/crecoai-assets-2?fileName="
        + str(out_path).replace("crecoAI/", ""),
        headers=api_headers(),
        files={"file": ("image.png", image, "image/png")},
    )
    image.close()

    image_url = "https://comic-assets.s3.ap-south-1.amazonaws.com/" + out_path
    print({"promptImages": image_url})

    return image_url
147
+
148
+
149
def download_image(url) -> Image.Image:
    """Download `url` and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    response = requests.get(url)
    # Fail on HTTP errors here, as download_file() does, instead of handing
    # an error page to the image decoder.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
152
+
153
+
154
def download_file(url, out_path: Path):
    """Stream `url` to `out_path`.

    The data is written to a sibling ``.part`` file and moved into place only
    after the download completes. An interrupted transfer therefore never
    leaves a truncated file at `out_path`, which callers that check for the
    file's existence would treat as a finished download.

    Raises:
        requests.HTTPError: when the server answers with an error status.
    """
    tmp_path = Path(str(out_path) + ".part")
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, out_path)
    finally:
        # Only present when the download or the move failed.
        if tmp_path.exists():
            tmp_path.unlink()
160
+
161
+
162
def pickPoses():
    """Return four randomly chosen character-sheet images as 512x512 RGB PIL images."""
    random_images = random.sample(characterSheets, 4)
    poses = []
    prefix = "https://comic-assets.s3.ap-south-1.amazonaws.com/"

    random_images_with_prefix = [prefix + img for img in random_images]

    print(random_images_with_prefix)
    for imageUrl in random_images_with_prefix:
        # Each image is fetched once. The earlier version also downloaded it
        # through download_image() into a variable that was never used.
        input_image_bytes = read_url(imageUrl)
        pose_image = Image.open(io.BytesIO(input_image_bytes)).convert("RGB")
        pose_image = pose_image.resize((512, 512))
        poses.append(pose_image)

    return poses
188
+
189
+
190
def construct_default_s3_url(key):
    """Return the comic-assets bucket URL for the object stored under `key`."""
    bucket_root = "https://comic-assets.s3.ap-south-1.amazonaws.com/"
    return "".join((bucket_root, key))
192
+
193
+
194
def read_url(url: str):
    """Open `url` with urllib and return the whole response body as bytes."""
    with urllib.request.urlopen(url) as response:
        payload = response.read()
    return payload
197
+
198
+
199
def disable_safety_checker(pipe):
    """Turn off the safety checker on `pipe` by setting ``pipe.safety_checker`` to None."""
    # The earlier version also defined an inner pass-through function that was
    # never installed or called; it has been removed.
    pipe.safety_checker = None
internals/util/config.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from internals.data.task import Task
4
+
5
# Process-wide settings. They start with these defaults and are overwritten by
# set_root_dir() and set_configs_from_task() below.
env = "gamma"  # "prod" or "gamma"; selects the API hosts
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""  # sent as the "Access-Token" header by api_headers()
root_dir = ""  # directory of the main script, see set_root_dir()
10
+
11
+
12
def set_root_dir(main_file: str):
    """Store the directory that contains `main_file` as the application root."""
    global root_dir
    absolute_main = os.path.abspath(main_file)
    root_dir = os.path.dirname(absolute_main)
15
+
16
+
17
def set_configs_from_task(task: Task):
    """Copy environment, NSFW settings and access token from `task` into module state.

    The environment is "prod" when the task's queue name starts with "prod",
    otherwise "gamma".
    """
    global env, nsfw_threshold, nsfw_access, access_token
    queue_name = task.get_queue_name()
    env = "prod" if queue_name.startswith("prod") else "gamma"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()
27
+
28
+
29
def get_root_dir():
    """Return the application root recorded by set_root_dir()."""
    return root_dir
32
+
33
+
34
def get_environment():
    """Return the current environment name stored in the module-level `env`."""
    return env
37
+
38
+
39
def get_nsfw_threshold():
    """Return the NSFW threshold currently stored in module state."""
    return nsfw_threshold
42
+
43
+
44
def get_nsfw_access():
    """Return the NSFW-access flag currently stored in module state."""
    return nsfw_access
47
+
48
+
49
def api_headers():
    """Return the request headers carrying the current access token."""
    headers = {}
    headers["Access-Token"] = access_token
    return headers
53
+
54
+
55
def api_endpoint():
    """Return the public API base URL for the current environment."""
    host = "prod" if env == "prod" else "gamma"
    return "https://" + host + ".pratilipicomics.com"
60
+
61
+
62
def comic_url():
    """Return the internal load-balancer URL for the current environment."""
    prod_url = "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
    gamma_url = "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
    return prod_url if env == "prod" else gamma_url
internals/util/failure_hander.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from internals.data.dataAccessor import updateSource
6
+ from internals.data.task import Task
7
+ from internals.util.config import set_configs_from_task
8
+ from internals.util.slack import Slack
9
+
10
+
11
class FailureHandler:
    """Persist the in-flight task so an unfinished run can be marked FAILED on next start.

    `handle()` writes the task to a marker file, `clear` removes the marker
    when the decorated call returns a result, and `register()` reports any
    marker left behind by a previous process.
    """

    __task_path = Path.home() / ".cache" / "inference" / "task.json"

    @staticmethod
    def register():
        """Mark the task from a leftover marker file as FAILED and delete the marker."""
        path = FailureHandler.__task_path
        path.parent.mkdir(parents=True, exist_ok=True)
        if not path.exists():
            return

        try:
            raw = json.loads(path.read_text())
        except ValueError as e:
            # A truncated or corrupt marker cannot identify a task. Drop it so
            # start-up does not fail on the same file forever.
            print("Discarding unreadable task marker", e)
            os.remove(path)
            return

        task = Task(raw)
        set_configs_from_task(task)
        # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE"))
        updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
        os.remove(path)

    @staticmethod
    def clear(func):
        """Decorator: delete the marker file when `func` returns a non-None result."""
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            # A None result leaves the marker in place for register().
            if result is not None:
                path = FailureHandler.__task_path
                if path.exists():
                    os.remove(path)
            return result

        return wrapper

    @staticmethod
    def handle(task: Task):
        """Write the raw task to the marker file before processing starts."""
        path = FailureHandler.__task_path
        # register() may not have run yet, so make sure the directory exists.
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(json.dumps(task.get_raw()))
internals/util/image.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ from PIL import Image
4
+
5
+
6
def to_bytes(image: Image.Image) -> bytes:
    """Encode `image` as JPEG and return the encoded bytes.

    Images in a mode the JPEG writer cannot store (for example RGBA, LA or P)
    are converted to RGB first instead of failing inside ``save``.
    """
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    with io.BytesIO() as output:
        image.save(output, format="JPEG")
        return output.getvalue()
10
+
11
+
12
def resize_image_to512(image: Image.Image) -> Image.Image:
    """Resize `image` so its longer side is 512 px, keeping the aspect ratio.

    Square and portrait images get a height of 512; landscape images get a
    width of 512.
    """
    width, height = image.size
    if width > height:
        target = (512, int(512 * height / width))
    else:
        target = (int(512 * width / height), 512)
    return image.resize(target)
internals/util/lora_style.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Union
5
+
6
+ import boto3
7
+ import torch
8
+ from lora_diffusion import patch_pipe, tune_lora_scale
9
+ from pydash import chain
10
+
11
+ from internals.data.dataAccessor import getStyles
12
+ from internals.util.commons import download_file
13
+
14
+
15
class LoraStyle:
    """Catalogue of LoRA styles plus helpers that apply one to a pipeline.

    Styles are fetched through `getStyles()`. When that returns nothing
    usable, a built-in set located under the model directory is used.
    """

    class LoraPatcher:
        """Applies one style's LoRA weights to a pipeline and turns them off again."""

        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            """Patch the pipeline with this style and set its LoRA scale."""
            path = self.__style["path"]
            # Only .pt / .safetensors files are patched; other files are ignored.
            # NOTE(review): the scale calls are assumed to belong inside this
            # branch; indentation is not recoverable from the diff view.
            if str(path).endswith((".pt", ".safetensors")):
                patch_pipe(self.pipe, self.__style["path"])
                tune_lora_scale(self.pipe.unet, self.__style["weight"])
                tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self):
            """Return extra pipeline keyword arguments (none for this patcher)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale to 0.0 on the unet and text encoder."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        """Patcher used when no style matches; it only tunes LoRA scales down."""

        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            """Patch will act as cleanup, to tune down any corrupted lora"""
            self.cleanup()

        def kwargs(self):
            """Return extra pipeline keyword arguments (none for this patcher)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale to 0.0 on the unet and text encoder."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    def load(self, model_dir: str):
        """Remember `model_dir` and load the style catalogue."""
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        """Refresh the catalogue from `getStyles()`, falling back to the built-in styles.

        Raises:
            Exception: when a style's weight file is missing on disk.
        """
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        else:
            self.__styles = self.__get_default_styles(model_dir)
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return `prompt` prefixed with the trigger text of style `key`, if known."""
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        """Return a patcher for style `key`, or an EmptyLoraPatcher when it is unknown."""
        if key in self.__styles:
            style = self.__styles[key]
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        """Build the catalogue from API records, downloading missing weight files.

        Records without attributes or without a ``path`` attribute are skipped.
        If no record qualifies, the built-in styles are returned.
        """
        styles = {}
        download_dir = Path.home() / ".cache" / "lora"
        # parents=True so this also works when ~/.cache does not exist yet.
        download_dir.mkdir(parents=True, exist_ok=True)
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is not None:
                attr = json.loads(item["attributes"])
                if "path" in attr:
                    file_path = Path(download_dir / attr["path"].split("/")[-1])

                    if not file_path.exists():
                        s3_uri = attr["path"]
                        download_file(s3_uri, file_path)

                    styles[item["tag"]] = {
                        "path": str(file_path),
                        "weight": attr["weight"],
                        "type": attr["type"],
                        "text": attr["text"],
                        "negativePrompt": attr["negativePrompt"],
                    }
        if len(styles) == 0:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        """Return the built-in style table, with paths rooted at `model_dir`."""
        return {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self):
        """A method to verify if lora exists within the required path otherwise throw error"""

        for item in self.__styles.keys():
            if not os.path.exists(self.__styles[item]["path"]):
                raise Exception(
                    "Lora style model "
                    + item
                    + " not found at path: "
                    + self.__styles[item]["path"]
                )
internals/util/slack.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import sleep
2
+ from typing import Optional
3
+
4
+ import requests
5
+
6
+ from internals.data.task import Task
7
+ from internals.util.config import get_environment
8
+
9
+
10
class Slack:
    """Posts task summaries and failure reports to Slack incoming webhooks."""

    def __init__(self):
        # NOTE(review): these webhook URLs are credentials committed to source.
        # They should be rotated and loaded from configuration or the environment.
        # self.webhook_url = "https://hooks.slack.com/services/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"
        self.webhook_url = "https://hooks.slack.com/services/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"
        self.error_webhook = "https://hooks.slack.com/services/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"

    def send_alert(self, task: Task, args: Optional[dict]):
        """Post a summary of `task`, merged with `args`, to the main webhook.

        Bookkeeping fields are removed, falsy values are skipped, and list
        values are rendered comma-separated.
        """
        raw = task.get_raw().copy()

        raw["environment"] = get_environment()
        for field in ("queue_name", "attempt", "timestamp", "task_id", "maskImageUrl"):
            raw.pop(field, None)

        if args is not None:
            raw.update(args)

        lines = []
        for key, value in raw.items():
            if not value:
                continue
            if isinstance(value, list):
                # str() each item so lists of numbers do not break the join.
                value = ", ".join(map(str, value))
            lines.append(f"*{key}*: {value}\n")
        message = "".join(lines)

        requests.post(
            self.webhook_url,
            headers={"Content-Type": "application/json"},
            json={"text": message},
        )

    def error_alert(self, task: Task, e: Exception):
        """Post the raw task and the exception text to the error webhook."""
        requests.post(
            self.error_webhook,
            headers={"Content-Type": "application/json"},
            json={
                "text": "Task failed:\n{} \n error is: \n {}".format(task.get_raw(), e)
            },
        )

    def auto_send_alert(self, func):
        """Decorator: after `func` returns, send an alert for its first positional argument.

        The value returned by `func` is passed to `send_alert` as the extra
        fields and is then returned to the caller unchanged.
        """
        import functools

        @functools.wraps(func)
        def inner(*args, **kwargs):
            rargs = func(*args, **kwargs)
            self.send_alert(args[0], rargs)
            return rargs

        return inner
models/ade20k/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/ade20k/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .base import *
models/ade20k/base.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
2
+
3
+ import os
4
+
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from scipy.io import loadmat
10
+ from torch.nn.modules import BatchNorm2d
11
+
12
+ from . import resnet
13
+ from . import mobilenet
14
+
15
+
16
# Default number of output classes for the decoder.
NUM_CLASS = 150
base_path = os.path.dirname(os.path.abspath(__file__))  # current file path
# Data files that sit next to this module.
colors_path = os.path.join(base_path, 'color150.mat')
classes_path = os.path.join(base_path, 'object150_info.csv')

# Both files are read when the module is imported.
segm_options = dict(colors=loadmat(colors_path)['colors'],
                    classes=pd.read_csv(classes_path),)
23
+
24
+
25
class NormalizeTensor:
    """Channel-wise ``(x - mean) / std`` normalisation for image tensors.

    `mean` and `std` are expanded to shape (1, C, 1, 1) before being applied,
    so the input is expected to be a batched (N, C, H, W) tensor.
    """

    def __init__(self, mean, std, inplace=False):
        """Store the normalisation parameters.

        Args:
            mean (sequence): Per-channel means.
            std (sequence): Per-channel standard deviations.
            inplace (bool, optional): Modify the input tensor instead of a copy.
        """
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        """Return the normalised tensor; the input is untouched unless `inplace` is set."""
        target = tensor if self.inplace else tensor.clone()

        # Parameters are created with the tensor's own dtype and device.
        mean = torch.as_tensor(self.mean, dtype=target.dtype, device=target.device)
        std = torch.as_tensor(self.std, dtype=target.dtype, device=target.device)

        target.sub_(mean[None, :, None, None])
        target.div_(std[None, :, None, None])
        return target
53
+
54
+
55
+ # Model Builder
56
class ModelBuilder:
    """Static factory helpers that assemble segmentation encoders and decoders."""

    @staticmethod
    def weights_init(m):
        """Initialise one module in place, chosen by its class name.

        Classes whose name contains ``'Conv'`` get Kaiming-normal weights;
        classes whose name contains ``'BatchNorm'`` get weight 1 and bias
        1e-4.  Any other module is left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif 'BatchNorm' in classname:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
        """Create an encoder network.

        Args:
            arch: One of 'mobilenetv2dilated', 'resnet18', 'resnet18dilated',
                'resnet50', 'resnet50dilated' (case-insensitive).
            fc_dim: Accepted but not used when building the encoder.
            weights: Path of a checkpoint to load (non-strict).  When empty,
                the backbone is requested with ``pretrained=True`` instead.

        Returns:
            The encoder module.

        Raises:
            Exception: If ``arch`` is not one of the names above.
        """
        use_pretrained = len(weights) == 0
        arch = arch.lower()

        if arch == 'mobilenetv2dilated':
            backbone = mobilenet.__dict__['mobilenetv2'](pretrained=use_pretrained)
            net_encoder = MobileNetV2Dilated(backbone, dilate_scale=8)
        elif arch in ('resnet18', 'resnet18dilated', 'resnet50', 'resnet50dilated'):
            dilated = arch.endswith('dilated')
            factory_name = arch[:-len('dilated')] if dilated else arch
            backbone = resnet.__dict__[factory_name](pretrained=use_pretrained)
            if dilated:
                net_encoder = ResnetDilated(backbone, dilate_scale=8)
            else:
                net_encoder = Resnet(backbone)
        else:
            raise Exception('Architecture undefined!')

        # encoders are usually pretrained, so no custom weight init here
        if not use_pretrained:
            print('Loading weights for net_encoder')
            state = torch.load(weights, map_location=lambda storage, loc: storage)
            net_encoder.load_state_dict(state, strict=False)
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup',
                      fc_dim=512, num_class=NUM_CLASS,
                      weights='', use_softmax=False, drop_last_conv=False):
        """Create a decoder network, initialise it and optionally load weights.

        Args:
            arch: 'ppm_deepsup' or 'c1_deepsup' (case-insensitive).
            fc_dim: Channel count of the encoder feature map fed to the decoder.
            num_class: Number of output classes.
            weights: Path of a checkpoint to load (non-strict); empty to skip.
            use_softmax: Forwarded to the decoder constructor.
            drop_last_conv: Forwarded to the decoder constructor.

        Returns:
            The decoder module.

        Raises:
            Exception: If ``arch`` is not a known decoder name.
        """
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            decoder_cls = PPMDeepsup
        elif arch == 'c1_deepsup':
            decoder_cls = C1DeepSup
        else:
            raise Exception('Architecture undefined!')

        net_decoder = decoder_cls(
            num_class=num_class,
            fc_dim=fc_dim,
            use_softmax=use_softmax,
            drop_last_conv=drop_last_conv)

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_decoder')
            state = torch.load(weights, map_location=lambda storage, loc: storage)
            net_decoder.load_state_dict(state, strict=False)
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs):
        """Build the decoder from the checkpoint found under ``weights_path``.

        The decoder is always created with ``use_softmax=True``.  Extra
        positional and keyword arguments are accepted and ignored.
        """
        checkpoint = os.path.join(
            weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(
            arch=arch_decoder, fc_dim=fc_dim, weights=checkpoint,
            use_softmax=True, drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation,
                    *arts, **kwargs):
        """Build the encoder, loading the checkpoint only when ``segmentation`` is truthy.

        With a falsy ``segmentation`` an empty weights path is passed on, so
        ``build_encoder`` requests the pretrained backbone instead.  Extra
        positional and keyword arguments are accepted and ignored.
        """
        checkpoint = ''
        if segmentation:
            checkpoint = os.path.join(
                weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=checkpoint)
137
+
138
+
139
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Return a 3x3 conv (padding 1, no bias) -> BatchNorm2d -> in-place ReLU stack."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = BatchNorm2d(out_planes)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
145
+
146
+
147
class SegmentationModule(nn.Module):
    """Encoder/decoder semantic-segmentation network with multi-scale inference.

    Use :meth:`predict` as the entry point; :meth:`forward` runs a single
    encoder + decoder pass and requires an explicit ``segSize``.
    """

    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,  # None for Default encoder
                 net_dec=None,  # None for Default decoder
                 encode=None,  # {None, 'binary', 'color', 'sky'}
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,  # {0, 1, 2, 3}
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        """Build (or adopt) the encoder and decoder.

        Args:
            weights_path: Root directory handed to ``ModelBuilder`` for
                locating the encoder/decoder checkpoints.
            num_classes: Accepted but not used by this class.
            arch_encoder: ``"resnet50dilated"`` (paired with the
                ``"ppm_deepsup"`` decoder, fc_dim 2048) or
                ``"mobilenetv2dilated"`` (paired with ``"c1_deepsup"``,
                fc_dim 320).
            drop_last_conv: Forwarded to the decoder builder.
            net_enc: Ready-made encoder; built via ``ModelBuilder`` when None.
            net_dec: Ready-made decoder; built via ``ModelBuilder`` when None.
            encode: Stored on the instance; not otherwise used here.
            use_default_normalization: If True, :meth:`predict` normalizes
                its input with the stored mean/std first.
            return_feature_maps: If True, :meth:`forward` also returns the
                encoder feature maps and :meth:`predict` returns aggregated
                features instead of class predictions.
            return_feature_maps_level: Index (0..3) of the encoder feature
                map aggregated by :meth:`predict`.
            return_feature_maps_only: Accepted but not used by this class.
            **kwargs: Ignored.

        Raises:
            NotImplementedError: For an unsupported ``arch_encoder``.
        """
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        # Kept for callers that read it; the methods below follow the device
        # of their input tensors instead.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if net_enc is None:
            # get_encoder requires `segmentation`; True makes it load the
            # encoder checkpoint stored alongside the decoder checkpoint.
            self.encoder = ModelBuilder.get_encoder(segmentation=True, **model_builder_kwargs)
        else:
            self.encoder = net_enc
        self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
        self.use_default_normalization = use_default_normalization
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])

        self.encode = encode

        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        """Apply the default mean/std normalization.

        Raises:
            ValueError: If any value of ``tensor`` lies outside [0, 1].
        """
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    @property
    def feature_maps_channels(self):
        """Channel count assumed for the selected feature-map level."""
        return 256 * 2**(self.return_feature_maps_level)  # 256, 512, 1024, 2048

    def forward(self, img_data, segSize=None):
        """Run one encoder + decoder pass.

        Args:
            img_data: Input batch for the encoder.
            segSize: Output size handed to the decoder; mandatory.

        Returns:
            The decoder output, or ``(decoder output, encoder feature maps)``
            when ``return_feature_maps`` is set.

        Raises:
            NotImplementedError: If ``segSize`` is None.
        """
        if segSize is None:
            raise NotImplementedError("Please pass segSize param. By default: (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps
        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        """Return a float mask that is 1 where ``pred`` equals any of ``classes``."""
        wanted = torch.as_tensor(classes, dtype=torch.long, device=pred.device)
        return (pred[..., None] == wanted).any(-1).float()

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        """Sum the score channels listed in ``classes``.

        Args:
            scores: Tensor indexed as ``scores[:, c]`` per class ``c``.
            classes: Iterable of channel indices.

        Returns:
            The summed channels, or None when ``classes`` is empty.  For a
            single class the result is a view of ``scores``.
        """
        res = None
        for c in classes:
            channel = scores[:, c]
            # Out-of-place add: an in-place `+=` would write through the view
            # taken for the first class and corrupt the caller's `scores`.
            res = channel if res is None else res + channel
        return res

    def predict(self, tensor, imgSizes=(-1,),  # (300, 375, 450, 525, 600)
                segSize=None):
        """Entry-point for segmentation. Use this method instead of forward.

        Arguments:
            tensor {torch.Tensor} -- BCHW

        Keyword Arguments:
            imgSizes {tuple or list} -- input sizes to evaluate and average
                over; ``-1`` means "use the tensor at its own size".
                default: (-1,)
                original implementation: (300, 375, 450, 525, 600)
            segSize {tuple} -- (H, W) of the output; defaults to the spatial
                size of ``tensor``.

        Returns:
            ``(pred, result)`` where ``pred`` is the per-pixel argmax over the
            averaged scores and ``result`` is the list of per-size decoder
            outputs; or the averaged feature maps alone when
            ``return_feature_maps`` is set.
        """
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])
        device = tensor.device

        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)
            scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1], device=device)
            features = None
            if self.return_feature_maps:
                # Only needed (and potentially large) in feature-map mode.
                features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1], device=device)

            result = []
            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                scores = scores + pred_current / len(imgSizes)

                # Disclaimer: only one feature level is aggregated.
                if self.return_feature_maps:
                    level_map = fmaps[self.return_feature_maps_level]
                    features = features + F.interpolate(level_map, size=segSize) / len(imgSizes)

            _, pred = torch.max(scores, dim=1)

            if self.return_feature_maps:
                return features

            return pred, result

    def get_edges(self, t):
        """Mark every element that differs from a neighbour along the last two axes.

        Args:
            t: 4-D tensor; neighbours are compared along dims 2 and 3.

        Returns:
            A half-precision tensor of the same shape, on the same device as
            ``t``, holding 1 at edge positions and 0 elsewhere.
        """
        edge = torch.zeros(t.size(), dtype=torch.bool, device=t.device)
        diff_w = t[:, :, :, 1:] != t[:, :, :, :-1]
        diff_h = t[:, :, 1:, :] != t[:, :, :-1, :]
        # Both elements of a differing pair count as edge pixels.
        edge[:, :, :, 1:] |= diff_w
        edge[:, :, :, :-1] |= diff_w
        edge[:, :, 1:, :] |= diff_h
        edge[:, :, :-1, :] |= diff_h
        return edge.half()
287
+
288
+
289
+ # pyramid pooling, deep supervision
290
class PPMDeepsup(nn.Module):
    """Pyramid-pooling decoder with an auxiliary deep-supervision head."""

    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        """Create the pooling branches, the main head and the auxiliary head.

        Args:
            num_class: Number of output classes.
            fc_dim: Channel count of the last encoder feature map.
            use_softmax: If True, ``forward`` returns upsampled softmax scores.
            pool_scales: Output sizes of the adaptive average pools.
            drop_last_conv: If True, ``forward`` returns the concatenated
                pyramid features without applying the final head.
        """
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        # One branch per scale: pool -> 1x1 conv to 512 channels -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
            for scale in pool_scales
        ])
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        """Decode a list of encoder feature maps.

        Returns the pyramid features when ``drop_last_conv`` is set, softmax
        scores resized to ``segSize`` when ``use_softmax`` is set, and
        otherwise a ``(main, auxiliary)`` pair of log-softmax outputs.
        """
        conv5 = conv_out[-1]
        target_hw = (conv5.size()[2], conv5.size()[3])

        # Resize every pooled branch back to conv5's spatial size and stack
        # them, with conv5 itself first, along the channel axis.
        pyramid = [conv5]
        pyramid.extend(
            nn.functional.interpolate(
                branch(conv5), target_hw, mode='bilinear', align_corners=False)
            for branch in self.ppm
        )
        ppm_out = torch.cat(pyramid, 1)

        if self.drop_last_conv:
            return ppm_out

        x = self.conv_last(ppm_out)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            return nn.functional.softmax(x, dim=1)

        # deep supervision on the second-to-last feature map
        aux = self.cbr_deepsup(conv_out[-2])
        aux = self.dropout_deepsup(aux)
        aux = self.conv_last_deepsup(aux)

        x = nn.functional.log_softmax(x, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)
        return (x, aux)
353
+
354
+
355
class Resnet(nn.Module):
    """Feature extractor that reuses a ResNet's stem and four residual stages."""

    # Sub-modules taken from the wrapped network, in registration order
    # (everything except the average pool and the FC layer).
    _REUSED = ('conv1', 'bn1', 'relu1',
               'conv2', 'bn2', 'relu2',
               'conv3', 'bn3', 'relu3',
               'maxpool',
               'layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self, orig_resnet):
        """Adopt the listed sub-modules of ``orig_resnet``."""
        super(Resnet, self).__init__()
        for name in self._REUSED:
            setattr(self, name, getattr(orig_resnet, name))

    def forward(self, x, return_feature_maps=False):
        """Run the stem and the four stages.

        Returns:
            The list of all four stage outputs when ``return_feature_maps``
            is True, otherwise a one-element list with the last stage output.
        """
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        conv_out = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            conv_out.append(x)

        return conv_out if return_feature_maps else [x]
391
+
392
+ # Resnet Dilated
393
class ResnetDilated(nn.Module):
    """ResNet feature extractor whose late stages use dilation instead of stride."""

    def __init__(self, orig_resnet, dilate_scale=8):
        """Rewrite the strides/dilations of ``orig_resnet`` in place and adopt its layers.

        With ``dilate_scale`` 8, layer3 is processed with dilate 2 and layer4
        with dilate 4; with 16 only layer4 is processed (dilate 2).  Any other
        value leaves the layers unchanged.
        """
        super().__init__()

        if dilate_scale == 8:
            orig_resnet.layer3.apply(lambda m: self._nostride_dilate(m, dilate=2))
            orig_resnet.layer4.apply(lambda m: self._nostride_dilate(m, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(lambda m: self._nostride_dilate(m, dilate=2))

        # take the (possibly pretrained) resnet, except AvgPool and FC
        for name in ('conv1', 'bn1', 'relu1',
                     'conv2', 'bn2', 'relu2',
                     'conv3', 'bn3', 'relu3',
                     'maxpool',
                     'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, name, getattr(orig_resnet, name))

    def _nostride_dilate(self, m, dilate):
        """Remove stride 2 from a conv module and dilate its 3x3 kernel.

        Only modules whose class name contains ``'Conv'`` are touched.  A
        stride-2 conv becomes stride 1 and, if 3x3, gets dilation and padding
        ``dilate // 2``; any other 3x3 conv gets dilation and padding
        ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return

        if m.stride == (2, 2):
            # the convolution with stride
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # other convolutions
            rate = dilate

        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Run the stem and the four stages.

        Returns:
            The list of all four stage outputs when ``return_feature_maps``
            is True, otherwise a one-element list with the last stage output.
        """
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        conv_out = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]
458
+
459
class MobileNetV2Dilated(nn.Module):
    """MobileNetV2 feature extractor whose late blocks use dilation instead of stride."""

    def __init__(self, orig_net, dilate_scale=8):
        """Adopt ``orig_net.features`` (minus its last entry) and dilate late blocks.

        With ``dilate_scale`` 8, blocks [7, 14) are processed with dilate 2
        and blocks [14, end) with dilate 4; with 16 only blocks [14, end) are
        processed (dilate 2).  Any other value leaves the blocks unchanged.
        """
        super(MobileNetV2Dilated, self).__init__()

        # take the (possibly pretrained) mobilenet features
        self.features = orig_net.features[:-1]

        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        # (first block, one-past-last block, dilate) triples to apply.
        if dilate_scale == 8:
            schedule = (
                (self.down_idx[-2], self.down_idx[-1], 2),
                (self.down_idx[-1], self.total_idx, 4),
            )
        elif dilate_scale == 16:
            schedule = ((self.down_idx[-1], self.total_idx, 2),)
        else:
            schedule = ()

        for first, stop, rate in schedule:
            for i in range(first, stop):
                # Bind `rate` as a default so each lambda keeps its own value.
                self.features[i].apply(
                    lambda m, rate=rate: self._nostride_dilate(m, dilate=rate))

    def _nostride_dilate(self, m, dilate):
        """Remove stride 2 from a conv module and dilate its 3x3 kernel.

        Only modules whose class name contains ``'Conv'`` are touched.  A
        stride-2 conv becomes stride 1 and, if 3x3, gets dilation and padding
        ``dilate // 2``; any other 3x3 conv gets dilation and padding
        ``dilate``.
        """
        if 'Conv' not in m.__class__.__name__:
            return

        if m.stride == (2, 2):
            # the convolution with stride
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # other convolutions
            rate = dilate

        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Run the feature blocks.

        Returns:
            When ``return_feature_maps`` is True, the outputs of the blocks
            listed in ``down_idx`` followed by the final output; otherwise a
            one-element list holding the final output.
        """
        if not return_feature_maps:
            return [self.features(x)]

        conv_out = []
        for i in range(self.total_idx):
            x = self.features[i](x)
            if i in self.down_idx:
                conv_out.append(x)
        conv_out.append(x)
        return conv_out
512
+
513
+
514
+ # last conv, deep supervision
515
class C1DeepSup(nn.Module):
    """Single-conv decoder head with an auxiliary deep-supervision branch."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
        """Create the main and auxiliary conv heads.

        Args:
            num_class: Number of output classes.
            fc_dim: Channel count of the last encoder feature map.
            use_softmax: If True, ``forward`` returns upsampled softmax scores.
            drop_last_conv: If True, ``forward`` returns the features produced
                by the first conv block without classifying them.
        """
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        reduced = fc_dim // 4
        self.cbr = conv3x3_bn_relu(fc_dim, reduced, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, reduced, 1)

        # last conv
        self.conv_last = nn.Conv2d(reduced, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(reduced, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Decode a list of encoder feature maps.

        Returns the intermediate features when ``drop_last_conv`` is set,
        softmax scores resized to ``segSize`` when ``use_softmax`` is set, and
        otherwise a ``(main, auxiliary)`` pair of log-softmax outputs.
        """
        x = self.cbr(conv_out[-1])

        if self.drop_last_conv:
            return x

        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            return nn.functional.softmax(x, dim=1)

        # deep supervision on the second-to-last feature map
        aux = self.conv_last_deepsup(self.cbr_deepsup(conv_out[-2]))

        x = nn.functional.log_softmax(x, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)
        return (x, aux)
553
+
554
+
555
+ # last conv
556
class C1(nn.Module):
    """Single-conv decoder head."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        """Create the 3x3 conv block and the 1x1 classifier.

        Args:
            num_class: Number of output classes.
            fc_dim: Channel count of the last encoder feature map.
            use_softmax: Selects the output form of ``forward``.
        """
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        reduced = fc_dim // 4
        self.cbr = conv3x3_bn_relu(fc_dim, reduced, 1)

        # last conv
        self.conv_last = nn.Conv2d(reduced, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Classify the last feature map in ``conv_out``.

        Returns softmax scores resized to ``segSize`` when ``use_softmax`` is
        set, otherwise log-softmax scores at the feature-map resolution.
        """
        logits = self.conv_last(self.cbr(conv_out[-1]))

        if not self.use_softmax:
            return nn.functional.log_softmax(logits, dim=1)

        # is True during inference
        logits = nn.functional.interpolate(
            logits, size=segSize, mode='bilinear', align_corners=False)
        return nn.functional.softmax(logits, dim=1)
579
+
580
+
581
+ # pyramid pooling
582
class PPM(nn.Module):
    """Pyramid-pooling decoder head."""

    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        """Create the pooling branches and the classification head.

        Args:
            num_class: Number of output classes.
            fc_dim: Channel count of the last encoder feature map.
            use_softmax: Selects the output form of ``forward``.
            pool_scales: Output sizes of the adaptive average pools.
        """
        super(PPM, self).__init__()
        self.use_softmax = use_softmax

        # One branch per scale: pool -> 1x1 conv to 512 channels -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
            for scale in pool_scales
        ])

        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )

    def forward(self, conv_out, segSize=None):
        """Classify the last feature map in ``conv_out``.

        Returns softmax scores resized to ``segSize`` when ``use_softmax`` is
        set, otherwise log-softmax scores at the feature-map resolution.
        """
        conv5 = conv_out[-1]
        target_hw = (conv5.size()[2], conv5.size()[3])

        # Resize every pooled branch back to conv5's spatial size and stack
        # them, with conv5 itself first, along the channel axis.
        pyramid = [conv5]
        pyramid.extend(
            nn.functional.interpolate(
                branch(conv5), target_hw, mode='bilinear', align_corners=False)
            for branch in self.ppm
        )
        x = self.conv_last(torch.cat(pyramid, 1))

        if not self.use_softmax:
            return nn.functional.log_softmax(x, dim=1)

        # is True during inference
        x = nn.functional.interpolate(
            x, size=segSize, mode='bilinear', align_corners=False)
        return nn.functional.softmax(x, dim=1)
models/ade20k/color150.mat ADDED
Binary file (502 Bytes). View file
 
models/ade20k/mobilenet.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This MobileNetV2 implementation is modified from the following repository:
3
+ https://github.com/tonylins/pytorch-mobilenet-v2
4
+ """
5
+
6
+ import torch.nn as nn
7
+ import math
8
+ from .utils import load_url
9
+ from .segm_lib.nn import SynchronizedBatchNorm2d
10
+
11
# Every normalisation layer built in this module uses this class.
BatchNorm2d = SynchronizedBatchNorm2d


__all__ = ['mobilenetv2']


# Download location of the pretrained checkpoint, keyed by model name.
model_urls = dict(
    mobilenetv2='http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
)
20
+
21
+
22
def conv_bn(inp, oup, stride):
    """Return a 3x3 conv (padding 1, no bias) -> batch norm -> in-place ReLU6 stack."""
    conv = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
    norm = BatchNorm2d(oup)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, activation)
28
+
29
+
30
def conv_1x1_bn(inp, oup):
    """Return a 1x1 conv (stride 1, no bias) -> batch norm -> in-place ReLU6 stack."""
    conv = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
    norm = BatchNorm2d(oup)
    activation = nn.ReLU6(inplace=True)
    return nn.Sequential(conv, norm, activation)
36
+
37
+
38
class InvertedResidual(nn.Module):
    """MobileNetV2 block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 linear projection."""

    def __init__(self, inp, oup, stride, expand_ratio):
        """Build the block.

        Args:
            inp: Number of input channels.
            oup: Number of output channels.
            stride: Stride of the depthwise conv; must be 1 or 2.
            expand_ratio: Multiplier giving the hidden width
                ``round(inp * expand_ratio)``; the 1x1 expansion stage is
                omitted when it equals 1.
        """
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        # The skip connection is only used when input and output shapes match.
        self.use_res_connect = self.stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw
            stages += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        stages += [
            # dw
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the input back when the skip connection is enabled."""
        out = self.conv(x)
        return x + out if self.use_res_connect else out
77
+
78
+
79
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier: a stack of inverted-residual blocks plus a linear head."""

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        """Build the feature stack and the classifier, then initialise all weights.

        Args:
            n_class: Number of classifier outputs.
            input_size: Expected input resolution; must be a multiple of 32.
            width_mult: Channel multiplier applied to every stage.
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        # Each row: expansion ratio t, output channels c, repeats n, stride s
        # (the stride applies to the first block of the row only).
        stage_settings = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]
        base_input_channel = 32
        base_last_channel = 1280

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(base_input_channel * width_mult)
        if width_mult > 1.0:
            self.last_channel = int(base_last_channel * width_mult)
        else:
            self.last_channel = base_last_channel
        layers = [conv_bn(3, input_channel, 2)]

        # building inverted residual blocks
        for t, c, n, s in stage_settings:
            output_channel = int(c * width_mult)
            for i in range(n):
                stride = s if i == 0 else 1
                layers.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel

        # building last several layers
        layers.append(conv_1x1_bn(input_channel, self.last_channel))
        self.features = nn.Sequential(*layers)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        """Return class scores: features, mean over the two spatial axes, classifier."""
        x = self.features(x)
        x = x.mean(3).mean(2)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
143
+
144
+
145
def mobilenetv2(pretrained=False, **kwargs):
    """Constructs a MobileNet_V2 model with 1000 output classes.

    Args:
        pretrained (bool): If True, the checkpoint registered under
            'mobilenetv2' in ``model_urls`` is fetched and loaded
            (non-strict) into the model.
        **kwargs: Forwarded to the ``MobileNetV2`` constructor.
    """
    model = MobileNetV2(n_class=1000, **kwargs)
    if not pretrained:
        return model

    state_dict = load_url(model_urls['mobilenetv2'])
    model.load_state_dict(state_dict, strict=False)
    return model
models/ade20k/object150_info.csv ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Idx,Ratio,Train,Val,Stuff,Name
2
+ 1,0.1576,11664,1172,1,wall
3
+ 2,0.1072,6046,612,1,building;edifice
4
+ 3,0.0878,8265,796,1,sky
5
+ 4,0.0621,9336,917,1,floor;flooring
6
+ 5,0.0480,6678,641,0,tree
7
+ 6,0.0450,6604,643,1,ceiling
8
+ 7,0.0398,4023,408,1,road;route
9
+ 8,0.0231,1906,199,0,bed
10
+ 9,0.0198,4688,460,0,windowpane;window
11
+ 10,0.0183,2423,225,1,grass
12
+ 11,0.0181,2874,294,0,cabinet
13
+ 12,0.0166,3068,310,1,sidewalk;pavement
14
+ 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
15
+ 14,0.0151,1804,190,1,earth;ground
16
+ 15,0.0118,6666,796,0,door;double;door
17
+ 16,0.0110,4269,411,0,table
18
+ 17,0.0109,1691,160,1,mountain;mount
19
+ 18,0.0104,3999,441,0,plant;flora;plant;life
20
+ 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
21
+ 20,0.0103,3261,318,0,chair
22
+ 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
23
+ 22,0.0074,709,75,1,water
24
+ 23,0.0067,3296,315,0,painting;picture
25
+ 24,0.0065,1191,106,0,sofa;couch;lounge
26
+ 25,0.0061,1516,162,0,shelf
27
+ 26,0.0060,667,69,1,house
28
+ 27,0.0053,651,57,1,sea
29
+ 28,0.0052,1847,224,0,mirror
30
+ 29,0.0046,1158,128,1,rug;carpet;carpeting
31
+ 30,0.0044,480,44,1,field
32
+ 31,0.0044,1172,98,0,armchair
33
+ 32,0.0044,1292,184,0,seat
34
+ 33,0.0033,1386,138,0,fence;fencing
35
+ 34,0.0031,698,61,0,desk
36
+ 35,0.0030,781,73,0,rock;stone
37
+ 36,0.0027,380,43,0,wardrobe;closet;press
38
+ 37,0.0026,3089,302,0,lamp
39
+ 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
40
+ 39,0.0024,804,99,0,railing;rail
41
+ 40,0.0023,1453,153,0,cushion
42
+ 41,0.0023,411,37,0,base;pedestal;stand
43
+ 42,0.0022,1440,162,0,box
44
+ 43,0.0022,800,77,0,column;pillar
45
+ 44,0.0020,2650,298,0,signboard;sign
46
+ 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
47
+ 46,0.0019,367,36,0,counter
48
+ 47,0.0018,311,30,1,sand
49
+ 48,0.0018,1181,122,0,sink
50
+ 49,0.0018,287,23,1,skyscraper
51
+ 50,0.0018,468,38,0,fireplace;hearth;open;fireplace
52
+ 51,0.0018,402,43,0,refrigerator;icebox
53
+ 52,0.0018,130,12,1,grandstand;covered;stand
54
+ 53,0.0018,561,64,1,path
55
+ 54,0.0017,880,102,0,stairs;steps
56
+ 55,0.0017,86,12,1,runway
57
+ 56,0.0017,172,11,0,case;display;case;showcase;vitrine
58
+ 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
59
+ 58,0.0017,930,109,0,pillow
60
+ 59,0.0015,139,18,0,screen;door;screen
61
+ 60,0.0015,564,52,1,stairway;staircase
62
+ 61,0.0015,320,26,1,river
63
+ 62,0.0015,261,29,1,bridge;span
64
+ 63,0.0014,275,22,0,bookcase
65
+ 64,0.0014,335,60,0,blind;screen
66
+ 65,0.0014,792,75,0,coffee;table;cocktail;table
67
+ 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
68
+ 67,0.0014,1309,138,0,flower
69
+ 68,0.0013,1112,113,0,book
70
+ 69,0.0013,266,27,1,hill
71
+ 70,0.0013,659,66,0,bench
72
+ 71,0.0012,331,31,0,countertop
73
+ 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
74
+ 73,0.0012,369,36,0,palm;palm;tree
75
+ 74,0.0012,144,9,0,kitchen;island
76
+ 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
77
+ 76,0.0010,324,33,0,swivel;chair
78
+ 77,0.0009,304,27,0,boat
79
+ 78,0.0009,170,20,0,bar
80
+ 79,0.0009,68,6,0,arcade;machine
81
+ 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
82
+ 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
83
+ 82,0.0008,492,49,0,towel
84
+ 83,0.0008,2510,269,0,light;light;source
85
+ 84,0.0008,440,39,0,truck;motortruck
86
+ 85,0.0008,147,18,1,tower
87
+ 86,0.0008,583,56,0,chandelier;pendant;pendent
88
+ 87,0.0007,533,61,0,awning;sunshade;sunblind
89
+ 88,0.0007,1989,239,0,streetlight;street;lamp
90
+ 89,0.0007,71,5,0,booth;cubicle;stall;kiosk
91
+ 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
92
+ 91,0.0007,135,12,0,airplane;aeroplane;plane
93
+ 92,0.0007,83,5,1,dirt;track
94
+ 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
95
+ 94,0.0006,1003,104,0,pole
96
+ 95,0.0006,182,12,1,land;ground;soil
97
+ 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
98
+ 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
99
+ 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
100
+ 99,0.0006,965,114,0,bottle
101
+ 100,0.0006,117,13,0,buffet;counter;sideboard
102
+ 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
103
+ 102,0.0006,108,9,1,stage
104
+ 103,0.0006,557,55,0,van
105
+ 104,0.0006,52,4,0,ship
106
+ 105,0.0005,99,5,0,fountain
107
+ 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
108
+ 107,0.0005,292,31,0,canopy
109
+ 108,0.0005,77,9,0,washer;automatic;washer;washing;machine
110
+ 109,0.0005,340,38,0,plaything;toy
111
+ 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
112
+ 111,0.0005,465,49,0,stool
113
+ 112,0.0005,50,4,0,barrel;cask
114
+ 113,0.0005,622,75,0,basket;handbasket
115
+ 114,0.0005,80,9,1,waterfall;falls
116
+ 115,0.0005,59,3,0,tent;collapsible;shelter
117
+ 116,0.0005,531,72,0,bag
118
+ 117,0.0005,282,30,0,minibike;motorbike
119
+ 118,0.0005,73,7,0,cradle
120
+ 119,0.0005,435,44,0,oven
121
+ 120,0.0005,136,25,0,ball
122
+ 121,0.0005,116,24,0,food;solid;food
123
+ 122,0.0004,266,31,0,step;stair
124
+ 123,0.0004,58,12,0,tank;storage;tank
125
+ 124,0.0004,418,83,0,trade;name;brand;name;brand;marque
126
+ 125,0.0004,319,43,0,microwave;microwave;oven
127
+ 126,0.0004,1193,139,0,pot;flowerpot
128
+ 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
129
+ 128,0.0004,347,36,0,bicycle;bike;wheel;cycle
130
+ 129,0.0004,52,5,1,lake
131
+ 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
132
+ 131,0.0004,108,13,0,screen;silver;screen;projection;screen
133
+ 132,0.0004,201,30,0,blanket;cover
134
+ 133,0.0004,285,21,0,sculpture
135
+ 134,0.0004,268,27,0,hood;exhaust;hood
136
+ 135,0.0003,1020,108,0,sconce
137
+ 136,0.0003,1282,122,0,vase
138
+ 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
139
+ 138,0.0003,453,57,0,tray
140
+ 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
141
+ 140,0.0003,397,44,0,fan
142
+ 141,0.0003,92,8,1,pier;wharf;wharfage;dock
143
+ 142,0.0003,228,18,0,crt;screen
144
+ 143,0.0003,570,59,0,plate
145
+ 144,0.0003,217,22,0,monitor;monitoring;device
146
+ 145,0.0003,206,19,0,bulletin;board;notice;board
147
+ 146,0.0003,130,14,0,shower
148
+ 147,0.0003,178,28,0,radiator
149
+ 148,0.0002,504,57,0,glass;drinking;glass
150
+ 149,0.0002,775,96,0,clock
151
+ 150,0.0002,421,56,0,flag
models/ade20k/resnet.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
2
+
3
+ import math
4
+
5
+ import torch.nn as nn
6
+ from torch.nn import BatchNorm2d
7
+
8
+ from .utils import load_url
9
+
10
+ __all__ = ['ResNet', 'resnet50']
11
+
12
+
13
# Pretrained ImageNet checkpoints hosted at sceneparsing.csail.mit.edu, keyed by
# the name of the constructor in this module that loads them. Every constructor
# that accepts ``pretrained=True`` needs an entry here.
model_urls = {
    'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
}
16
+
17
+
18
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
22
+
23
+
24
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    ``expansion`` is the ratio between the block's output channels and
    ``planes``; it is 1 for this block.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: number of input channels.
            planes: number of channels produced by both convolutions.
            stride: stride of the first convolution.
            downsample: optional module applied to the input to build the
                shortcut branch; when ``None`` the input itself is the shortcut.
        """
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run conv-bn-relu, conv-bn, add the shortcut, then apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the raw input unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
54
+
55
+
56
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution outputs ``planes * 4`` channels, which is the value
    advertised by the ``expansion`` class attribute.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: number of input channels.
            planes: number of channels inside the bottleneck.
            stride: stride of the middle 3x3 convolution.
            downsample: optional module applied to the input to build the
                shortcut branch; when ``None`` the input itself is the shortcut.
        """
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv-bn stages, add the shortcut, then apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the raw input unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
93
+
94
+
95
class ResNet(nn.Module):
    """ResNet classifier whose stem is three stacked 3x3 convolutions.

    The stem (3 -> 64 -> 64 -> 128 channels, first conv with stride 2) is
    followed by a max-pool, four residual stages, a 7x7 average pool and a
    linear classifier.
    """

    def __init__(self, block, layers, num_classes=1000):
        """
        Args:
            block: residual block class; must expose an ``expansion`` attribute
                and accept ``(inplanes, planes, stride, downsample)``.
            layers: sequence of four ints, the number of blocks in each stage.
            num_classes: size of the final linear layer's output.
        """
        super().__init__()
        # Channel count entering the first residual stage (output of the stem).
        self.inplanes = 128

        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialise: conv weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))),
        # batch-norm scale to 1 and shift to 0.
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                k_h, k_w = layer.kernel_size[0], layer.kernel_size[1]
                fan = k_h * k_w * layer.out_channels
                layer.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(layer, BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1-conv + batch-norm projection shortcut.
        """
        out_planes = planes * block.expansion

        shortcut = None
        if stride != 1 or self.inplanes != out_planes:
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Apply stem, max-pool, the four stages, average pool and the classifier."""
        stem = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        for conv, norm, act in stem:
            x = act(norm(conv(x)))
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        flat = x.view(x.size(0), -1)
        return self.fc(flat)
159
+
160
+
161
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50: ``Bottleneck`` blocks with stage depths 3, 4, 6, 3.

    Args:
        pretrained (bool): if True, download the ``'resnet50'`` checkpoint listed
            in ``model_urls`` and load it non-strictly (``strict=False``).
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net

    weights = load_url(model_urls['resnet50'])
    net.load_state_dict(weights, strict=False)
    return net
171
+
172
+
173
def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18: ``BasicBlock`` blocks with stage depths 2, 2, 2, 2.

    Args:
        pretrained (bool): if True, download the ``'resnet18'`` checkpoint listed
            in ``model_urls`` and load it into the model.
        **kwargs: forwarded to the ``ResNet`` constructor.

    Raises:
        KeyError: if ``pretrained`` is True but ``model_urls`` has no
            ``'resnet18'`` entry.
    """
    url = None
    if pretrained:
        # Resolve the checkpoint URL before building the network so that a
        # missing entry fails fast with an actionable message instead of a bare
        # KeyError raised after the whole model has been constructed.
        url = model_urls.get('resnet18')
        if url is None:
            raise KeyError(
                "no pretrained checkpoint URL registered for 'resnet18' in "
                "model_urls; add one or call resnet18(pretrained=False)"
            )

    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(url))
    return model
models/ade20k/segm_lib/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/ade20k/segm_lib/nn/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/ade20k/segm_lib/nn/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .modules import *
2
+ from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
models/ade20k/segm_lib/nn/modules/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : __init__.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12
+ from .replicate import DataParallelWithCallback, patch_replication_callback
models/ade20k/segm_lib/nn/modules/batchnorm.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import collections
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ from torch.nn.modules.batchnorm import _BatchNorm
17
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18
+
19
+ from .comm import SyncMaster
20
+
21
+ __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22
+
23
+
24
+ def _sum_ft(tensor):
25
+ """sum over the first and last dimention"""
26
+ return tensor.sum(dim=0).sum(dim=-1)
27
+
28
+
29
+ def _unsqueeze_ft(tensor):
30
+ """add new dementions at the front and the tail"""
31
+ return tensor.unsqueeze(0).unsqueeze(-1)
32
+
33
+
34
# Message each replica hands to the master: per-channel sum, per-channel sum of
# squares, and the number of elements those sums were taken over.
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
# Message the master sends back to each replica. Note the first field is named
# `sum` but is filled with the global *mean* (see `_data_parallel_master`);
# `forward` unpacks it as `(mean, inv_std)`.
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    """Batch norm whose training statistics are reduced across data-parallel replicas.

    When the module has not been replicated (``_is_parallel`` is False) or is in
    evaluation mode, ``forward`` defers to ``F.batch_norm``. Otherwise each
    replica sends its partial sums to the master replica (copy id 0), which
    computes the global mean / inverse std and sends them back.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
        """
        Args:
            num_features: number of channels; forwarded to ``_BatchNorm``.
            eps: lower bound applied to the variance before the inverse square
                root in the synchronized path; also forwarded to ``_BatchNorm``.
            momentum: forwarded to ``_BatchNorm``; ``1 - momentum`` is the decay
                applied to the accumulated statistics on every synchronized step.
            affine: whether the layer has learnable weight and bias.
        """
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        # Replication state, filled in by `__data_parallel_replicate__`.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

        # Custom running statistics: decayed accumulators plus the decayed
        # iteration weight they are divided by (see `_compute_mean_std`).
        self._moving_average_fraction = 1. - momentum
        self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
        self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
        self.register_buffer('_running_iter', torch.ones(1))
        # Seed the accumulators from the current running statistics, scaled by
        # the iteration weight.
        self._tmp_running_mean = self.running_mean.clone() * self._running_iter
        self._tmp_running_var = self.running_var.clone() * self._running_iter

    def forward(self, input):
        """Normalize ``input``; returns a tensor with the same shape as ``input``."""
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Replication callback: mark this copy as parallel and wire it to the master.

        Args:
            ctx: context object shared by all copies of this module; the master
                copy stores its ``SyncMaster`` on it as ``ctx.sync_master``.
            copy_id: index of this copy; 0 denotes the master.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it.

        Args:
            intermediates: list of ``(identifier, _ChildMessage)`` pairs, one per
                replica.

        Returns:
            list of ``(identifier, _MasterMessage)`` pairs, ordered by the device
            index of each replica's ``sum`` tensor.
        """
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        # NOTE(review): presumably adds the (sum, ssum) pairs of all replicas onto
        # the first device -- confirm against torch.nn.parallel._functions.
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)

        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            # `broadcasted` is consumed as a flat sequence holding two tensors
            # (mean, inv_std) per target device.
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
        """Return ``dest*alpha + delta*beta + bias``; ``dest`` itself is not modified."""
        return dest * alpha + delta * beta + bias

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard deviation from sum and square-sum.

        This method also maintains the moving average on the master device.

        Returns:
            ``(mean, inv_std)``, where ``inv_std`` is derived from the biased
            variance after it has been bounded below by ``self.eps``.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Decay the accumulators by (1 - momentum) and add the new batch statistics.
        self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
        self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
        self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)

        # The running statistics are the accumulators normalized by the decayed
        # iteration weight.
        self.running_mean = self._tmp_running_mean / self._running_iter
        self.running_var = self._tmp_running_var / self._running_iter

        return mean, bias_var.clamp(self.eps) ** -0.5
140
+
141
+
142
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics of that device only, which is fast and easy to implement but
    can make the statistics inaccurate.
    In this synchronized version, the statistics are instead computed over all
    training samples distributed across the devices.

    Note that in the one-GPU or CPU-only case this module falls back to the
    built-in PyTorch implementation.

    The mean and standard deviation are calculated per dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimate is kept with a default momentum of 0.001
    (the default of this implementation, not PyTorch's 0.1).

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.001
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 2D or 3D, then defer to the parent check."""
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        # NOTE(review): recent PyTorch base classes raise NotImplementedError from
        # `_check_input_dim`; confirm this call is still wanted for the targeted version.
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203
+
204
+
205
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics of that device only, which is fast and easy to implement but
    can make the statistics inaccurate.
    In this synchronized version, the statistics are instead computed over all
    training samples distributed across the devices.

    Note that in the one-GPU or CPU-only case this module falls back to the
    built-in PyTorch implementation.

    The mean and standard deviation are calculated per dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimate is kept with a default momentum of 0.001
    (the default of this implementation, not PyTorch's 0.1).

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.001
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 4D, then defer to the parent check."""
        ndim = input.dim()
        if ndim != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(ndim))
        # NOTE(review): recent PyTorch base classes raise NotImplementedError from
        # `_check_input_dim`; confirm this call is still wanted for the targeted version.
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266
+
267
+
268
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics of that device only, which is fast and easy to implement but
    can make the statistics inaccurate.
    In this synchronized version, the statistics are instead computed over all
    training samples distributed across the devices.

    Note that in the one-GPU or CPU-only case this module falls back to the
    built-in PyTorch implementation.

    The mean and standard deviation are calculated per dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimate is kept with a default momentum of 0.001
    (the default of this implementation, not PyTorch's 0.1).

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.001
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` is 5D, then defer to the parent check."""
        ndim = input.dim()
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(ndim))
        # NOTE(review): recent PyTorch base classes raise NotImplementedError from
        # `_check_input_dim`; confirm this call is still wanted for the targeted version.
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
models/ade20k/segm_lib/nn/modules/comm.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : comm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import queue
12
+ import collections
13
+ import threading
14
+
15
+ __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16
+
17
+
18
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    ``None`` is the "empty" marker, so ``None`` itself cannot be transported as
    a result.
    """

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake up a waiting ``get``.

        Raises:
            AssertionError: if the previously stored result has not been fetched.
        """
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return it and reset the slot."""
        with self._lock:
            # Re-check the predicate after every wake-up: Condition.wait() may
            # return without a matching put(), and a plain `if` would then hand
            # back None as if it were a result.
            while self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res
40
+
41
+
42
# Registry entry kept by the master for each slave: the future used to send the
# reply back to that slave.
_MasterRegistry = collections.namedtuple('MasterRegistry', ('result',))
# Fields of a slave's pipe: its identifier, the queue shared with the master,
# and the future on which the master's reply arrives.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ('identifier', 'queue', 'result'))


class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master/slave exchange."""

    def run_slave(self, msg):
        """Send ``msg`` to the master, wait for the reply, acknowledge it, and return it.

        The message is queued as ``(identifier, msg)``; after the reply has been
        read from ``result``, ``True`` is queued as the acknowledgement.
        """
        channel = self.queue
        channel.put((self.identifier, msg))
        reply = self.result.get()
        channel.put(True)
        return reply
54
+
55
+
56
class SyncMaster(object):
    """Master-side coordinator for the master/slave message exchange.

    - During replication, every slave device calls `register_slave(id)` and
      receives a `SlavePipe` for talking to the master.
    - During the forward pass, the master device calls `run_master`; the
      messages of all slaves are collected and handed to the registered callback.
    - The callback decides which message is sent back to each device, and the
      master delivers those replies.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: callable invoked with the list of collected
                `(identifier, message)` pairs; it must return a list of
                `(identifier, reply)` pairs whose first entry belongs to the
                master (identifier 0).
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def register_slave(self, identifier):
        """Register a slave device.

        If a forward pass has run since the last registration round, the
        registry is cleared first (the queue must be empty at that point).

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()

        reply_slot = FutureResult()
        self._registry[identifier] = _MasterRegistry(reply_slot)
        return SlavePipe(identifier, self._queue, reply_slot)

    def run_master(self, master_msg):
        """Main entry point for the master device in each forward pass.

        One message per registered slave is taken from the queue, the callback
        computes the replies, each slave's reply is delivered through its
        future, and finally one acknowledgement per slave is awaited.

        Args:
            master_msg: the master's own message; it is placed first, with
                identifier 0, in the list given to `master_callback`.

        Returns: the reply addressed to the master device.
        """
        self._activated = True

        collected = [(0, master_msg)]
        collected.extend(self._queue.get() for _ in range(self.nr_slaves))

        results = self._master_callback(collected)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for ident, reply in results:
            if ident != 0:
                self._registry[ident].result.put(reply)

        # Wait until every slave has acknowledged its reply.
        for _ in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of currently registered slaves."""
        return len(self._registry)
models/ade20k/segm_lib/nn/modules/replicate.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : replicate.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import functools
12
+
13
+ from torch.nn.parallel.data_parallel import DataParallel
14
+
15
+ __all__ = [
16
+ 'CallbackContext',
17
+ 'execute_replication_callbacks',
18
+ 'DataParallelWithCallback',
19
+ 'patch_replication_callback'
20
+ ]
21
+
22
+
23
class CallbackContext(object):
    """Empty attribute holder handed to replication callbacks as their shared context."""


def execute_replication_callbacks(modules):
    """Invoke `__data_parallel_replicate__(ctx, copy_id)` on every sub-module that defines it.

    `modules` is the list of replicas; the first one is treated as the master
    copy. One context object is created per sub-module of the master copy, and
    the sub-module found at the same traversal position in every replica
    receives that same context, so the copies can share information through it.

    Replicas are processed in list order, so the callbacks of the master copy
    run before those of any other copy.
    """
    primary = modules[0]
    contexts = [CallbackContext() for _ in primary.modules()]

    for copy_id, replica in enumerate(modules):
        for position, submodule in enumerate(replica.modules()):
            if not hasattr(submodule, '__data_parallel_replicate__'):
                continue
            submodule.__data_parallel_replicate__(contexts[position], copy_id)
48
+
49
+
50
class DataParallelWithCallback(DataParallel):
    """
    `DataParallel` that runs replication callbacks on the replicas it creates.

    After the parent `replicate` has produced the replicas, the
    `__data_parallel_replicate__(ctx, copy_id)` callback of every sub-module
    that defines one is invoked.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate via the parent class, run the callbacks, and return the replicas."""
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas
68
+
69
+
70
def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` instance so that its `replicate`
    also runs the replication callbacks. Useful when you have a customized
    `DataParallel` implementation.

    The instance's current `replicate` is wrapped; the wrapper calls it, runs
    `execute_replication_callbacks` on the result, and returns the replicas.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """
    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    @functools.wraps(original_replicate)
    def replicate_with_callbacks(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    data_parallel.replicate = replicate_with_callbacks
models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_numeric_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm.unittest import TorchTestCase
16
+
17
+
18
def handy_var(a, unbias=True):
    """Variance of `a` along dim 0, computed from the sum and the sum of squares.

    Divides by `n - 1` when `unbias` is true and by `n` otherwise, where `n` is
    the size of dim 0.
    """
    count = a.size(0)
    total = a.sum(dim=0)
    square_total = (a ** 2).sum(dim=0)
    scatter = square_total - total * total / count
    divisor = count - 1 if unbias else count
    return scatter / divisor
27
+
28
+
29
class NumericTestCase(TorchTestCase):
    """Check the sum / square-sum variance formulation against PyTorch's batch norm."""

    def testNumericBatchNorm(self):
        """Compare forward output, running statistics and input gradients."""
        a = torch.rand(16, 10)
        # The input is 2D (N, C), so the reference layer must be BatchNorm1d;
        # BatchNorm2d rejects anything but 4D input with a ValueError.
        bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)
        bn.train()

        # Reference path: PyTorch's batch norm.
        a_var1 = Variable(a, requires_grad=True)
        b_var1 = bn(a_var1)
        loss1 = b_var1.sum()
        loss1.backward()

        # Manual path: normalize with the biased variance from `handy_var`,
        # bounded below by 1e-5.
        a_var2 = Variable(a, requires_grad=True)
        a_mean2 = a_var2.mean(dim=0, keepdim=True)
        a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
        # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
        b_var2 = (a_var2 - a_mean2) / a_std2
        loss2 = b_var2.sum()
        loss2.backward()

        # With momentum=1 the running statistics equal the statistics of this batch.
        self.assertTensorClose(bn.running_mean, a.mean(dim=0))
        self.assertTensorClose(bn.running_var, handy_var(a))
        self.assertTensorClose(a_var1.data, a_var2.data)
        self.assertTensorClose(b_var1.data, b_var2.data)
        self.assertTensorClose(a_var1.grad, a_var2.grad)
54
+
55
+ if __name__ == '__main__':
56
+ unittest.main()
models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : test_sync_batchnorm.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+
9
+ import unittest
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.autograd import Variable
14
+
15
+ from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16
+ from sync_batchnorm.unittest import TorchTestCase
17
+
18
+
19
def handy_var(a, unbias=True):
    """Return the variance of ``a`` along dim 0 via the sum-of-squares identity.

    Args:
        a: tensor whose first dimension indexes the samples.
        unbias: when true, divide by ``n - 1``; otherwise divide by ``n``.

    Returns:
        The variance reduced over dim 0.
    """
    count = a.size(0)
    total = a.sum(dim=0)
    square_total = (a ** 2).sum(dim=0)
    # sum((x - mean) ** 2) == sum(x ** 2) - sum(x) ** 2 / n
    centered_square_sum = square_total - total * total / count
    divisor = count - 1 if unbias else count
    return centered_square_sum / divisor
28
+
29
+
30
def _find_bn(module):
    """Return the first batch-norm layer yielded by ``module.modules()``.

    Both the torch.nn and the synchronized 1d/2d variants count as a match.
    Returns ``None`` when no such layer is found.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)
    return next((layer for layer in module.modules() if isinstance(layer, bn_types)), None)
34
+
35
+
36
class SyncTestCase(TorchTestCase):
    """Compare SynchronizedBatchNorm layers with their torch.nn counterparts."""

    # The DataParallelWithCallback tests pin replicas to device_ids=[0, 1], so
    # they are skipped (instead of erroring) when fewer than two CUDA devices
    # are visible.
    _requires_two_gpus = unittest.skipIf(
        torch.cuda.device_count() < 2,
        'requires at least two CUDA devices',
    )

    def _syncParameters(self, bn1, bn2):
        """Reset both layers and, if both are affine, copy bn1's weight and bias into bn2."""
        bn1.reset_parameters()
        bn2.reset_parameters()
        if bn1.affine and bn2.affine:
            bn2.weight.data.copy_(bn1.weight.data)
            bn2.bias.data.copy_(bn1.bias.data)

    def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
        """Check the forward and backward for the customized batch normalization.

        Args:
            bn1: reference module.
            bn2: module under test.
            input: tensor fed to both modules.
            is_train: mode passed to ``train()`` on both modules.
            cuda: when true, move ``input`` to the GPU before running.
        """
        bn1.train(mode=is_train)
        bn2.train(mode=is_train)

        if cuda:
            input = input.cuda()

        # Either argument may be a wrapper, so locate the batch-norm layer first.
        self._syncParameters(_find_bn(bn1), _find_bn(bn2))

        input1 = Variable(input, requires_grad=True)
        output1 = bn1(input1)
        output1.sum().backward()
        input2 = Variable(input, requires_grad=True)
        output2 = bn2(input2)
        output2.sum().backward()

        self.assertTensorClose(input1.data, input2.data)
        self.assertTensorClose(output1.data, output2.data)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
        self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

    def testSyncBatchNormNormalTrain(self):
        """Unwrapped 1d layers agree in training mode."""
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)

    def testSyncBatchNormNormalEval(self):
        """Unwrapped 1d layers agree in eval mode."""
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)

    @_requires_two_gpus
    def testSyncBatchNormSyncTrain(self):
        """The 1d layer wrapped for two devices agrees with nn.BatchNorm1d in training mode."""
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)

    @_requires_two_gpus
    def testSyncBatchNormSyncEval(self):
        """The 1d layer wrapped for two devices agrees with nn.BatchNorm1d in eval mode."""
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)

    @_requires_two_gpus
    def testSyncBatchNorm2DSyncTrain(self):
        """The 2d layer wrapped for two devices agrees with nn.BatchNorm2d in training mode."""
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108
+
109
+
110
+ if __name__ == '__main__':
111
+ unittest.main()
models/ade20k/segm_lib/nn/modules/unittest.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # File : unittest.py
3
+ # Author : Jiayuan Mao
4
+ # Email : [email protected]
5
+ # Date : 27/01/2018
6
+ #
7
+ # This file is part of Synchronized-BatchNorm-PyTorch.
8
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
+ # Distributed under MIT License.
10
+
11
+ import unittest
12
+
13
+ import numpy as np
14
+ from torch.autograd import Variable
15
+
16
+
17
def as_numpy(v):
    """Convert ``v`` to a NumPy array on the CPU.

    A ``Variable`` is unwrapped through its ``.data`` attribute first; any
    other object is used as-is.
    """
    tensor = v.data if isinstance(v, Variable) else v
    return tensor.cpu().numpy()
21
+
22
+
23
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with a tolerance-based tensor comparison helper."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Assert that ``a`` and ``b`` are element-wise close.

        Args:
            a: first value; converted with ``as_numpy``.
            b: second value; converted with ``as_numpy``.
            atol: absolute tolerance forwarded to ``np.allclose``.
            rtol: relative tolerance forwarded to ``np.allclose``.

        Raises:
            AssertionError: when the values differ by more than the tolerances.
                The message reports the largest absolute and relative differences.
        """
        npa, npb = as_numpy(a), as_numpy(b)
        abs_diff = np.abs(npa - npb)
        # Floor the denominator by magnitude: flooring the signed value would
        # replace every negative entry with 1e-5 and inflate the reported rdiff.
        rel_diff = abs_diff / np.fmax(np.abs(npa), 1e-5)
        self.assertTrue(
            # rtol is forwarded so the parameter actually takes effect.
            np.allclose(npa, npb, atol=atol, rtol=rtol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, abs_diff.max(), rel_diff.max())
        )
models/ade20k/segm_lib/nn/parallel/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to