victorisgeek committed on
Commit
10ca7b7
1 Parent(s): 059b53a

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,162 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
+ .vscode
162
+ .DS_Store
README.md CHANGED
@@ -1,12 +1,50 @@
1
- ---
2
- title: Vrtclothes
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Clothes Virtual Try On
2
+ [![Open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SwayamInSync/clothes-virtual-try-on/blob/main/setup_gradio.ipynb)
3
+
4
+ ## Updates
5
+ - **[19/02/2024] From now on, this repo won't receive further updates from my side (Spoiler: it's not gone for good 😉. Expect its return, stronger than ever.) (Community contributions & issue discussions are still welcome 🤗)**
6
+ - [26/12/2023] Added the Gradio interface and removed all external dependencies
7
+ - [19/12/2023] Fixed the `openpose` installation and missing model weights issue
8
+ - [19/12/2023] Replaced the `remove.bg` dependency with `rembg`
9
+ - [26/04/2023] Fixed the GAN generation issue
10
+
11
+ ## Star History
12
+ [![Star History Chart](https://api.star-history.com/svg?repos=SwayamInSync/clothes-virtual-try-on&type=Date)](https://star-history.com/#SwayamInSync/clothes-virtual-try-on&Date)
13
+
14
+ ## Table of contents
15
+ - [Clothes Virtual Try On](#clothes-virtual-try-on)
16
+ - [Table of contents](#table-of-contents)
17
+ - [General info](#general-info)
18
+ - [Demo](#demo)
19
+ - [Block Diagram](#block-diagram)
20
+ - [Methodology](#methodology)
21
+ - [Usage](#usage)
22
+ - [Citation](#citation)
23
+
24
+ ## General info
25
+
26
+ This project is part of a Crework community project. When buying clothes online, it is difficult for a customer to select a desirable outfit on the first attempt because they cannot try the clothes on. This project aims to solve that problem.
27
+
28
+ <img width="383" alt="general_info" src="https://user-images.githubusercontent.com/63489382/163923011-c2898812-2491-4ec2-beb7-dcaaaf680e4f.png">
29
+
30
+
31
+ ## Demo
32
+
33
+ https://user-images.githubusercontent.com/63489382/163922795-5dbb0f52-95e4-42c6-95d7-2d965abeba6d.mp4
34
+
35
+
36
+
37
+ ## Block Diagram
38
+ ![block_diagram_whole](https://user-images.githubusercontent.com/63489382/163922947-c1677f79-ad6f-4550-affc-7d4e80f0d247.png)
39
+
40
+
41
+ ## Methodology
42
+ ![block_diagram_detailed](https://user-images.githubusercontent.com/63489382/163922991-86d148c2-1a97-48a5-b4ec-d8c16819374a.png)
43
+
44
+
45
+ ## Usage
46
+ - Just click the `Open in Colab` button at the top of this README file
47
+
48
+
49
+ ## Citation
50
+ **Work in progress**
assets/cloth/01260_00.jpg ADDED
assets/cloth/01430_00.jpg ADDED
assets/cloth/02783_00.jpg ADDED
assets/cloth/03751_00.jpg ADDED
assets/cloth/06429_00.jpg ADDED
assets/cloth/06802_00.jpg ADDED
assets/cloth/07429_00.jpg ADDED
assets/cloth/08348_00.jpg ADDED
assets/cloth/09933_00.jpg ADDED
assets/cloth/11028_00.jpg ADDED
assets/cloth/11351_00.jpg ADDED
assets/cloth/11791_00.jpg ADDED
assets/image/00891_00.jpg ADDED
assets/image/03615_00.jpg ADDED
assets/image/07445_00.jpg ADDED
assets/image/07573_00.jpg ADDED
assets/image/08909_00.jpg ADDED
assets/image/10549_00.jpg ADDED
client-side/app.py ADDED
@@ -0,0 +1,38 @@
1
+ from flask import Flask, request, jsonify, render_template
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ import base64
6
+
7
+ app = Flask(__name__)
8
+
9
+
10
+ @app.route('/')
11
+ def home():
12
+ return render_template("index.html")
13
+
14
+
15
+ @app.route("/preds", methods=['POST'])
16
+ def submit():
17
+ cloth = request.files['cloth']
18
+ model = request.files['model']
19
+
20
+ ## Replace this URL with the ngrok URL printed by the server notebook.
21
+ url = "http://e793-34-123-73-186.ngrok-free.app/api/transform"
22
+ print("sending")
23
+ response = requests.post(url=url, files={"cloth":cloth.stream, "model":model.stream})
24
+ op = Image.open(BytesIO(response.content))
25
+
26
+ buffer = BytesIO()
27
+ op.save(buffer, 'png')
28
+ buffer.seek(0)
29
+
30
+ data = buffer.read()
31
+ data = base64.b64encode(data).decode()
32
+
33
+
34
+ return render_template('index.html', op=data)
35
+ # return render_template('index.html', test=True)
36
+
37
+ if __name__ == '__main__':
38
+ app.run(debug=True)
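The relay above simply forwards the uploaded `cloth` and `model` files to the try-on API exposed over ngrok. As a hedged illustration of that contract, here is a minimal standalone client sketch; the ngrok hostname is a placeholder you must replace, and everything beyond the `/api/transform` route and the `cloth`/`model` field names used in `app.py` is an assumption.

```python
# Hypothetical standalone client for the try-on API that client-side/app.py relays to.
# The hostname below is a placeholder; use the ngrok URL printed by the server notebook.
import requests
from io import BytesIO
from PIL import Image

API_URL = "http://REPLACE-WITH-YOUR-NGROK-ID.ngrok-free.app/api/transform"  # placeholder

def try_on(cloth_path: str, model_path: str, out_path: str = "result.png") -> None:
    # The "cloth" and "model" field names mirror the form fields used in app.py.
    with open(cloth_path, "rb") as cloth, open(model_path, "rb") as model:
        response = requests.post(API_URL, files={"cloth": cloth, "model": model}, timeout=120)
    response.raise_for_status()
    Image.open(BytesIO(response.content)).save(out_path)  # server is assumed to return image bytes

if __name__ == "__main__":
    try_on("assets/cloth/01260_00.jpg", "assets/image/00891_00.jpg")
```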
client-side/static/css/style.css ADDED
File without changes
client-side/static/images/logo.png ADDED
client-side/static/output/dog.png ADDED
client-side/templates/index.html ADDED
@@ -0,0 +1,249 @@
1
+ <!doctype html>
2
+ <html>
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+
8
+ <script src="https://cdn.tailwindcss.com"></script>
9
+ </head>
10
+
11
+ <body>
12
+
13
+ <!-- <h1>Welcome to your own virtual clothing assistant</h1>
14
+ <form action="{{ url_for('submit') }}" method="post" enctype="multipart/form-data">
15
+ <input type="file" name="cloth">
16
+ <input type="file" name="model">
17
+ <button type="submit">Submit</button>
18
+ </form> -->
19
+
20
+
21
+
22
+ <header class="text-gray-400 bg-gray-900 body-font">
23
+ <div class="container mx-auto flex flex-wrap p-5 flex-col md:flex-row items-center">
24
+ <a class="flex title-font font-medium items-center text-white mb-4 md:mb-0">
25
+
26
+ <img src="{{url_for('static', filename='images/logo.png')}}" height="50" width="50" style="border-radius: 50%;"/>
27
+
28
+ <span class="ml-3 text-xl">Virtual Cloth Assistant</span>
29
+ </a>
30
+
31
+
32
+ </div>
33
+ </header>
34
+
35
+
36
+ <section class="text-gray-400 bg-gray-900 body-font">
37
+ <div class="container px-5 py-24 mx-auto">
38
+ <div class="text-center mb-20">
39
+ <h1 class="sm:text-3xl text-2xl font-medium title-font text-white mb-4">Virtual Cloth Assistant</h1>
40
+ <p class="text-base leading-relaxed xl:w-2/4 lg:w-3/4 mx-auto text-gray-400 text-opacity-80">Want to see how that cloth suits you?
41
+ <br>
42
+ Upgrade your shopping experience with an intelligent trial room.
43
+ <br> Check out our API and get your wish fulfilled in seconds!!</p>
44
+ <div class="flex mt-6 justify-center">
45
+ <div class="w-16 h-1 rounded-full bg-blue-500 inline-flex"></div>
46
+ </div>
47
+ </div>
48
+ <div class="flex flex-wrap sm:-m-4 -mx-4 -mb-10 -mt-4 md:space-y-0 space-y-6">
49
+ <div class="p-4 md:w-1/3 flex flex-col text-center items-center">
50
+ <div
51
+ class="w-20 h-20 inline-flex items-center justify-center rounded-full bg-gray-800 text-blue-400 mb-5 flex-shrink-0">
52
+ <svg fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round"
53
+ stroke-width="2" class="w-10 h-10" viewBox="0 0 24 24">
54
+ <path d="M22 12h-4l-3 9L9 3l-3 9H2"></path>
55
+ </svg>
56
+ </div>
57
+ <div class="flex-grow">
58
+ <h2 class="text-white text-lg title-font font-medium mb-3">The Problem</h2>
59
+ <p class="leading-relaxed text-base">While buying clothes online, it is difficult for a customer to select a desirable outfit in the first attempt because they can’t try on clothes before they are delivered physically.
60
+
61
+ </p>
62
+
63
+ </div>
64
+ </div>
65
+ <div class="p-4 md:w-1/3 flex flex-col text-center items-center">
66
+ <div
67
+ class="w-20 h-20 inline-flex items-center justify-center rounded-full bg-gray-800 text-blue-400 mb-5 flex-shrink-0">
68
+ <svg fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round"
69
+ stroke-width="2" class="w-10 h-10" viewBox="0 0 24 24">
70
+ <circle cx="6" cy="6" r="3"></circle>
71
+ <circle cx="6" cy="18" r="3"></circle>
72
+ <path d="M20 4L8.12 15.88M14.47 14.48L20 20M8.12 8.12L12 12"></path>
73
+ </svg>
74
+ </div>
75
+ <div class="flex-grow">
76
+ <h2 class="text-white text-lg title-font font-medium mb-3">The Solution</h2>
77
+ <p class="leading-relaxed text-base">E-commerce websites can be equipped with virtual trial rooms that allow users to try on multiple clothes virtually and select the best looking outfit in a single attempt.
78
+
79
+ </p>
80
+
81
+ </div>
82
+ </div>
83
+ <div class="p-4 md:w-1/3 flex flex-col text-center items-center">
84
+ <div
85
+ class="w-20 h-20 inline-flex items-center justify-center rounded-full bg-gray-800 text-blue-400 mb-5 flex-shrink-0">
86
+ <svg fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round"
87
+ stroke-width="2" class="w-10 h-10" viewBox="0 0 24 24">
88
+ <path d="M20 21v-2a4 4 0 00-4-4H8a4 4 0 00-4 4v2"></path>
89
+ <circle cx="12" cy="7" r="4"></circle>
90
+ </svg>
91
+ </div>
92
+ <div class="flex-grow">
93
+ <h2 class="text-white text-lg title-font font-medium mb-3">The Approach</h2>
94
+ <p class="leading-relaxed text-base">
95
+ We used deep learning to solve this problem: VCA (Virtual Clothing Assistant) lets a
+ user select the cloth they want to wear, upload a photo of themselves in any pose, and
+ VCA dresses that person in the selected cloth.
98
+ </p>
99
+
100
+ </div>
101
+ </div>
102
+ </div>
103
+ <!-- <button
104
+ class="flex mx-auto mt-16 text-white bg-blue-500 border-0 py-2 px-8 focus:outline-none hover:bg-blue-600 rounded text-lg">Button</button> -->
105
+ </div>
106
+ </section>
107
+
108
+
109
+
110
+ <section class="text-gray-400 bg-gray-900 body-font">
111
+ <form action="{{ url_for('submit') }}" method="post" enctype="multipart/form-data">
112
+ <div class="container mx-auto flex flex-col px-5 py-24 justify-center items-center">
113
+
114
+ <div class="flex flex-wrap -m-2">
115
+ <div class="p-1 xl:w-1/2 md:w-1/2 w-full">
116
+ <center><label class="block text-lr font-medium text-white-700"> Cloth Image </label></center>
117
+ <div
118
+ class="mt-1 flex justify-center px-6 pt-5 pb-6 border-2 border-gray-300 border-dashed rounded-md">
119
+ <div class="space-y-1 text-center">
120
+ <svg class="mx-auto h-12 w-12 text-gray-400" stroke="currentColor" fill="none"
121
+ viewBox="0 0 48 48" aria-hidden="true">
122
+ <path
123
+ d="M28 8H12a4 4 0 00-4 4v20m32-12v8m0 0v8a4 4 0 01-4 4H12a4 4 0 01-4-4v-4m32-4l-3.172-3.172a4 4 0 00-5.656 0L28 28M8 32l9.172-9.172a4 4 0 015.656 0L28 28m0 0l4 4m4-24h8m-4-4v8m-12 4h.02"
124
+ stroke-width="2" stroke-linecap="round" stroke-linejoin="round" />
125
+ </svg>
126
+ <div class="flex text-sm text-gray-600">
127
+ <!-- <label for="file-upload"
128
+ class="relative cursor-pointer rounded-md font-medium text-indigo-600 hover:text-indigo-500 focus-within:outline-none focus-within:ring-2 focus-within:ring-offset-2 focus-within:ring-indigo-500">
129
+ <span>Upload a file</span> -->
130
+ <input class="block w-full text-sm text-slate-500
131
+ file:mr-4 file:py-2 file:px-4
132
+ file:rounded-full file:border-0
133
+ file:text-sm file:font-semibold
134
+ file:bg-violet-50 file:text-violet-700
135
+ hover:file:bg-violet-100" id="cloth-upload" type="file" name="cloth">
136
+ </label>
137
+ <p class="pl-1">or drag and drop</p>
138
+ </div>
139
+ <p class="text-xs text-gray-500">PNG, JPG up to 10MB</p>
140
+ </div>
141
+ </div>
142
+
143
+ </div>
144
+ <div class="p-1 xl:w-1/2 md:w-1/2 w-full">
145
+ <center> <label class="block text-lr font-medium text-white-700"> Model Image </label></center>
146
+ <div
147
+ class="mt-1 flex justify-center px-6 pt-5 pb-6 border-2 border-gray-300 border-dashed rounded-md">
148
+ <div class="space-y-1 text-center">
149
+ <svg class="mx-auto h-12 w-12 text-gray-400" stroke="currentColor" fill="none"
150
+ viewBox="0 0 48 48" aria-hidden="true">
151
+ <path
152
+ d="M28 8H12a4 4 0 00-4 4v20m32-12v8m0 0v8a4 4 0 01-4 4H12a4 4 0 01-4-4v-4m32-4l-3.172-3.172a4 4 0 00-5.656 0L28 28M8 32l9.172-9.172a4 4 0 015.656 0L28 28m0 0l4 4m4-24h8m-4-4v8m-12 4h.02"
153
+ stroke-width="2" stroke-linecap="round" stroke-linejoin="round" />
154
+ </svg>
155
+ <div class="flex text-sm text-gray-600">
156
+ <!-- <label for="file-upload"
157
+ class="relative cursor-pointer rounded-md font-medium text-indigo-600 hover:text-indigo-500 focus-within:outline-none focus-within:ring-2 focus-within:ring-offset-2 focus-within:ring-indigo-500">
158
+ <span>Upload a file</span> -->
159
+ <input class="block w-full text-sm text-slate-500
160
+ file:mr-4 file:py-2 file:px-4
161
+ file:rounded-full file:border-0
162
+ file:text-sm file:font-semibold
163
+ file:bg-violet-50 file:text-violet-700
164
+ hover:file:bg-violet-100" id="model-upload" type="file" name="model">
165
+ </label>
166
+ <p class="pl-1">or drag and drop</p>
167
+ </div>
168
+ <p class="text-xs text-gray-500">PNG, JPG up to 10MB</p>
169
+ </div>
170
+ </div>
171
+ </div>
172
+ </div>
173
+
174
+ <br>
175
+ <br>
176
+
177
+ <div class="w-full md:w-2/3 flex flex-col mb-16 items-center text-center">
178
+ <h1 class="title-font sm:text-4xl text-3xl mb-4 font-medium text-white">Upload corresponding images
179
+ and get the Result</h1>
180
+ <p class="mb-8 leading-relaxed">Processing takes about 20-30 seconds, so relax and have some water.
181
+ </p>
182
+ <div class="flex w-full justify-center items-end">
183
+
184
+ <button type="submit"
185
+ class="inline-flex text-white bg-blue-500 border-0 py-2 px-6 focus:outline-none hover:bg-blue-600 rounded text-lg">Try
186
+ it</button>
187
+ </div>
188
+
189
+
190
+ </div>
191
+
192
+
193
+ </div>
194
+
195
+ <div>
196
+ {% if op %}
197
+ <center style="color: white; font-size: x-large;">HERE IS YOUR RESULT 🤗</center>
198
+
199
+ <center>
200
+ <div class="sm: w-3/4 mb-10 lg:mb-0 rounded-lg overflow-hidden">
201
+ <img alt="output" class="object-cover object-center h-2/4 w-2/4"
202
+ src="data:image/png;base64,{{ op }}">
203
+ </div>
204
+ </center>
205
+ {% endif %}
206
+ </div>
207
+ </form>
208
+
209
+
210
+ </section>
211
+
212
+
213
+ <footer class="text-gray-400 bg-gray-900 body-font">
214
+ <div
215
+ class="container px-5 py-24 mx-auto flex md:items-center lg:items-start md:flex-row md:flex-nowrap flex-wrap flex-col">
216
+ <div class="w-64 flex-shrink-0 md:mx-0 mx-auto text-center md:text-left md:mt-0 mt-10">
217
+ <a class="flex title-font font-medium items-center md:justify-start justify-center text-white">
218
+ <img src="{{url_for('static', filename='images/logo.png')}}" height="50" width="50" style="border-radius: 50%;"/>
219
+ <span class="ml-3 text-xl">V-Cloth Assistant</span>
220
+ </a>
221
+ </div>
222
+ <div class="flex-grow flex flex-wrap md:pr-20 -mb-10 md:text-left text-center order-first">
223
+ <div class="lg:w-1/4 md:w-1/2 w-full px-4">
224
+ <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">SWAYAM</h2>
225
+ </div>
226
+ <div class="lg:w-1/4 md:w-1/2 w-full px-4">
227
+ <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">PARTH</h2>
228
+ </div>
229
+ <div class="lg:w-1/4 md:w-1/2 w-full px-4">
230
+ <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">KEERTHI</h2>
231
+ </div>
232
+ <div class="lg:w-1/4 md:w-1/2 w-full px-4">
233
+ <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">NAVANEETH</h2>
234
+ </div>
235
+ </div>
236
+ </div>
237
+ <div class="bg-gray-800 bg-opacity-75">
238
+ <div class="container mx-auto py-4 px-5 flex flex-wrap flex-col sm:flex-row">
239
+ <p class="text-gray-400 text-sm text-center sm:text-left">© 2022 Crework Batch 3 —
240
+ <a href="https://twitter.com/knyttneve" class="text-gray-500 ml-1" rel="noopener noreferrer"
241
+ target="_blank">@Crework</a>
242
+ </p>
243
+ </div>
244
+ </div>
245
+ </footer>
246
+
247
+ </body>
248
+
249
+ </html>
cloth-mask.py ADDED
@@ -0,0 +1,124 @@
1
+ import os
2
+ from PIL import Image
3
+ import numpy as np
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as transforms
9
+
10
+ from networks.u2net import U2NET
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is available
12
+
13
+ image_dir = '/content/inputs/test/cloth'
14
+ result_dir = '/content/inputs/test/cloth-mask'
+ os.makedirs(result_dir, exist_ok=True)  # make sure the output directory exists
15
+ checkpoint_path = 'cloth_segm_u2net_latest.pth'
16
+
17
+ def load_checkpoint_mgpu(model, checkpoint_path):
18
+ if not os.path.exists(checkpoint_path):
19
+ print("----No checkpoints at given path----")
20
+ return
21
+ model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
22
+ new_state_dict = OrderedDict()
23
+ for k, v in model_state_dict.items():
24
+ name = k[7:] # remove `module.`
25
+ new_state_dict[name] = v
26
+
27
+ model.load_state_dict(new_state_dict)
28
+ print("----checkpoints loaded from path: {}----".format(checkpoint_path))
29
+ return model
30
+
31
+ class Normalize_image(object):
32
+ """Normalize given tensor into given mean and standard dev
33
+
34
+ Args:
35
+ mean (float): Desired mean to subtract from the tensors
36
+ std (float): Desired std to divide the tensors by
37
+ """
38
+
39
+ def __init__(self, mean, std):
40
+ assert isinstance(mean, float)
41
+ if isinstance(mean, float):
42
+ self.mean = mean
43
+
44
+ if isinstance(std, float):
45
+ self.std = std
46
+
47
+ self.normalize_1 = transforms.Normalize(self.mean, self.std)
48
+ self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
49
+ self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)
50
+
51
+ def __call__(self, image_tensor):
52
+ if image_tensor.shape[0] == 1:
53
+ return self.normalize_1(image_tensor)
54
+
55
+ elif image_tensor.shape[0] == 3:
56
+ return self.normalize_3(image_tensor)
57
+
58
+ elif image_tensor.shape[0] == 18:
59
+ return self.normalize_18(image_tensor)
60
+
61
+ else:
62
+ raise ValueError("Please set proper channels! Normalization is implemented only for 1, 3 and 18 channels.")
63
+
64
+
65
+ def get_palette(num_cls):
66
+ """ Returns the color map for visualizing the segmentation mask.
67
+ Args:
68
+ num_cls: Number of classes
69
+ Returns:
70
+ The color map
71
+ """
72
+ n = num_cls
73
+ palette = [0] * (n * 3)
74
+ for j in range(0, n):
75
+ lab = j
76
+ palette[j * 3 + 0] = 0
77
+ palette[j * 3 + 1] = 0
78
+ palette[j * 3 + 2] = 0
79
+ i = 0
80
+ while lab:
81
+ palette[j * 3 + 0] = 255
82
+ palette[j * 3 + 1] = 255
83
+ palette[j * 3 + 2] = 255
84
+ # palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
85
+ # palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
86
+ # palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
87
+ i += 1
88
+ lab >>= 3
89
+ return palette
90
+
91
+
92
+ transforms_list = []
93
+ transforms_list += [transforms.ToTensor()]
94
+ transforms_list += [Normalize_image(0.5, 0.5)]
95
+ transform_rgb = transforms.Compose(transforms_list)
96
+
97
+ net = U2NET(in_ch=3, out_ch=4)
98
+ net = load_checkpoint_mgpu(net, checkpoint_path)
99
+ net = net.to(device)
100
+ net = net.eval()
101
+
102
+ palette = get_palette(4)
103
+
104
+ images_list = sorted(os.listdir(image_dir))
105
+ for image_name in images_list:
106
+ img = Image.open(os.path.join(image_dir, image_name)).convert('RGB')
107
+ img_size = img.size
108
+ img = img.resize((768, 768), Image.BICUBIC)
109
+ image_tensor = transform_rgb(img)
110
+ image_tensor = torch.unsqueeze(image_tensor, 0)
111
+
112
+ output_tensor = net(image_tensor.to(device))
113
+ output_tensor = F.log_softmax(output_tensor[0], dim=1)
114
+ output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
115
+ output_tensor = torch.squeeze(output_tensor, dim=0)
116
+ output_tensor = torch.squeeze(output_tensor, dim=0)
117
+ output_arr = output_tensor.cpu().numpy()
118
+
119
+ output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
120
+ output_img = output_img.resize(img_size, Image.BICUBIC)
121
+
122
+ output_img.putpalette(palette)
123
+ output_img = output_img.convert('L')
124
+ output_img.save(os.path.join(result_dir, image_name[:-4]+'.jpg'))
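The script above processes a whole directory. As a sketch only, the same pipeline can be wrapped into a single-image helper: it assumes a `U2NET(in_ch=3, out_ch=4)` instance loaded via `load_checkpoint_mgpu` as shown above, inlines the `Normalize_image(0.5, 0.5)` transform as a plain 3-channel `transforms.Normalize`, and thresholds any non-background class to white, matching the all-white palette used above.

```python
# Minimal sketch: reuse the cloth-mask pipeline for a single image.
# `net` is assumed to be a U2NET(in_ch=3, out_ch=4) loaded with load_checkpoint_mgpu as above.
import torch
import torchvision.transforms as transforms
from PIL import Image

def predict_cloth_mask(net, image_path, device="cpu"):
    # Equivalent to Normalize_image(0.5, 0.5) for a 3-channel RGB input.
    transform_rgb = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    img = Image.open(image_path).convert("RGB")
    orig_size = img.size
    x = transform_rgb(img.resize((768, 768), Image.BICUBIC)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = net(x)[0]                      # first side output of U2NET, shape (1, 4, 768, 768)
    classes = logits.argmax(dim=1).squeeze(0)   # per-pixel class indices, shape (768, 768)
    mask = (classes > 0).byte().cpu().numpy() * 255  # any cloth class -> white, background -> black
    return Image.fromarray(mask, mode="L").resize(orig_size, Image.BICUBIC)
```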
datasets.py ADDED
@@ -0,0 +1,224 @@
1
+ import json
2
+ from os import path as osp
3
+
4
+ import numpy as np
5
+ from PIL import Image, ImageDraw
6
+ import torch
7
+ from torch.utils import data
8
+ from torchvision import transforms
9
+
10
+
11
+ class VITONDataset(data.Dataset):
12
+ def __init__(self, opt):
13
+ super(VITONDataset, self).__init__()
14
+ self.load_height = opt.load_height
15
+ self.load_width = opt.load_width
16
+ self.semantic_nc = opt.semantic_nc
17
+ self.data_path = osp.join(opt.dataset_dir, opt.dataset_mode)
18
+ self.transform = transforms.Compose([
19
+ transforms.ToTensor(),
20
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
21
+ ])
22
+
23
+ # load data list
24
+ img_names = []
25
+ c_names = []
26
+ with open(osp.join(opt.dataset_dir, opt.dataset_list), 'r') as f:
27
+ for line in f.readlines():
28
+ img_name, c_name = line.strip().split()
29
+ img_names.append(img_name)
30
+ c_names.append(c_name)
31
+
32
+ self.img_names = img_names
33
+ self.c_names = dict()
34
+ self.c_names['unpaired'] = c_names
35
+
36
+ def get_parse_agnostic(self, parse, pose_data):
37
+ parse_array = np.array(parse)
38
+ parse_upper = ((parse_array == 5).astype(np.float32) +
39
+ (parse_array == 6).astype(np.float32) +
40
+ (parse_array == 7).astype(np.float32))
41
+ parse_neck = (parse_array == 10).astype(np.float32)
42
+
43
+ r = 10
44
+ agnostic = parse.copy()
45
+
46
+ # mask arms
47
+ for parse_id, pose_ids in [(14, [2, 5, 6, 7]), (15, [5, 2, 3, 4])]:
48
+ mask_arm = Image.new('L', (self.load_width, self.load_height), 'black')
49
+ mask_arm_draw = ImageDraw.Draw(mask_arm)
50
+ i_prev = pose_ids[0]
51
+ for i in pose_ids[1:]:
52
+ if (pose_data[i_prev, 0] == 0.0 and pose_data[i_prev, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
53
+ continue
54
+ mask_arm_draw.line([tuple(pose_data[j]) for j in [i_prev, i]], 'white', width=r*10)
55
+ pointx, pointy = pose_data[i]
56
+ radius = r*4 if i == pose_ids[-1] else r*15
57
+ mask_arm_draw.ellipse((pointx-radius, pointy-radius, pointx+radius, pointy+radius), 'white', 'white')
58
+ i_prev = i
59
+ parse_arm = (np.array(mask_arm) / 255) * (parse_array == parse_id).astype(np.float32)
60
+ agnostic.paste(0, None, Image.fromarray(np.uint8(parse_arm * 255), 'L'))
61
+
62
+ # mask torso & neck
63
+ agnostic.paste(0, None, Image.fromarray(np.uint8(parse_upper * 255), 'L'))
64
+ agnostic.paste(0, None, Image.fromarray(np.uint8(parse_neck * 255), 'L'))
65
+
66
+ return agnostic
67
+
68
+ def get_img_agnostic(self, img, parse, pose_data):
69
+ parse_array = np.array(parse)
70
+ parse_head = ((parse_array == 4).astype(np.float32) +
71
+ (parse_array == 13).astype(np.float32))
72
+ parse_lower = ((parse_array == 9).astype(np.float32) +
73
+ (parse_array == 12).astype(np.float32) +
74
+ (parse_array == 16).astype(np.float32) +
75
+ (parse_array == 17).astype(np.float32) +
76
+ (parse_array == 18).astype(np.float32) +
77
+ (parse_array == 19).astype(np.float32))
78
+
79
+ r = 20
80
+ agnostic = img.copy()
81
+ agnostic_draw = ImageDraw.Draw(agnostic)
82
+
83
+ length_a = np.linalg.norm(pose_data[5] - pose_data[2])
84
+ length_b = np.linalg.norm(pose_data[12] - pose_data[9])
85
+ point = (pose_data[9] + pose_data[12]) / 2
86
+ pose_data[9] = point + (pose_data[9] - point) / length_b * length_a
87
+ pose_data[12] = point + (pose_data[12] - point) / length_b * length_a
88
+
89
+ # mask arms
90
+ agnostic_draw.line([tuple(pose_data[i]) for i in [2, 5]], 'gray', width=r*10)
91
+ for i in [2, 5]:
92
+ pointx, pointy = pose_data[i]
93
+ agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
94
+ for i in [3, 4, 6, 7]:
95
+ if (pose_data[i - 1, 0] == 0.0 and pose_data[i - 1, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
96
+ continue
97
+ agnostic_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'gray', width=r*10)
98
+ pointx, pointy = pose_data[i]
99
+ agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
100
+
101
+ # mask torso
102
+ for i in [9, 12]:
103
+ pointx, pointy = pose_data[i]
104
+ agnostic_draw.ellipse((pointx-r*3, pointy-r*6, pointx+r*3, pointy+r*6), 'gray', 'gray')
105
+ agnostic_draw.line([tuple(pose_data[i]) for i in [2, 9]], 'gray', width=r*6)
106
+ agnostic_draw.line([tuple(pose_data[i]) for i in [5, 12]], 'gray', width=r*6)
107
+ agnostic_draw.line([tuple(pose_data[i]) for i in [9, 12]], 'gray', width=r*12)
108
+ agnostic_draw.polygon([tuple(pose_data[i]) for i in [2, 5, 12, 9]], 'gray', 'gray')
109
+
110
+ # mask neck
111
+ pointx, pointy = pose_data[1]
112
+ agnostic_draw.rectangle((pointx-r*7, pointy-r*7, pointx+r*7, pointy+r*7), 'gray', 'gray')
113
+ agnostic.paste(img, None, Image.fromarray(np.uint8(parse_head * 255), 'L'))
114
+ agnostic.paste(img, None, Image.fromarray(np.uint8(parse_lower * 255), 'L'))
115
+
116
+ return agnostic
117
+
118
+ def __getitem__(self, index):
119
+ img_name = self.img_names[index]
120
+ c_name = {}
121
+ c = {}
122
+ cm = {}
123
+ for key in self.c_names:
124
+ c_name[key] = self.c_names[key][index]
125
+ c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')
126
+ c[key] = transforms.Resize(self.load_width, interpolation=2)(c[key])
127
+ cm[key] = Image.open(osp.join(self.data_path, 'cloth-mask', c_name[key]))
128
+ cm[key] = transforms.Resize(self.load_width, interpolation=0)(cm[key])
129
+
130
+ c[key] = self.transform(c[key]) # [-1,1]
131
+ cm_array = np.array(cm[key])
132
+ cm_array = (cm_array >= 128).astype(np.float32)
133
+ cm[key] = torch.from_numpy(cm_array) # [0,1]
134
+ cm[key].unsqueeze_(0)
135
+
136
+ # load pose image
137
+ pose_name = img_name.replace('.jpg', '_rendered.png')
138
+ pose_rgb = Image.open(osp.join(self.data_path, 'openpose-img', pose_name))
139
+ pose_rgb = transforms.Resize(self.load_width, interpolation=2)(pose_rgb)
140
+ pose_rgb = self.transform(pose_rgb) # [-1,1]
141
+
142
+ pose_name = img_name.replace('.jpg', '_keypoints.json')
143
+ with open(osp.join(self.data_path, 'openpose-json', pose_name), 'r') as f:
144
+ pose_label = json.load(f)
145
+ pose_data = pose_label['people'][0]['pose_keypoints_2d']
146
+ pose_data = np.array(pose_data)
147
+ pose_data = pose_data.reshape((-1, 3))[:, :2]
148
+
149
+ # load parsing image
150
+ parse_name = img_name.replace('.jpg', '.png')
151
+ parse = Image.open(osp.join(self.data_path, 'image-parse', parse_name))
152
+ parse = transforms.Resize(self.load_width, interpolation=0)(parse)
153
+ parse_agnostic = self.get_parse_agnostic(parse, pose_data)
154
+ parse_agnostic = torch.from_numpy(np.array(parse_agnostic)[None]).long()
155
+
156
+ labels = {
157
+ 0: ['background', [0, 10]],
158
+ 1: ['hair', [1, 2]],
159
+ 2: ['face', [4, 13]],
160
+ 3: ['upper', [5, 6, 7]],
161
+ 4: ['bottom', [9, 12]],
162
+ 5: ['left_arm', [14]],
163
+ 6: ['right_arm', [15]],
164
+ 7: ['left_leg', [16]],
165
+ 8: ['right_leg', [17]],
166
+ 9: ['left_shoe', [18]],
167
+ 10: ['right_shoe', [19]],
168
+ 11: ['socks', [8]],
169
+ 12: ['noise', [3, 11]]
170
+ }
171
+ parse_agnostic_map = torch.zeros(20, self.load_height, self.load_width, dtype=torch.float)
172
+ parse_agnostic_map.scatter_(0, parse_agnostic, 1.0)
173
+ new_parse_agnostic_map = torch.zeros(self.semantic_nc, self.load_height, self.load_width, dtype=torch.float)
174
+ for i in range(len(labels)):
175
+ for label in labels[i][1]:
176
+ new_parse_agnostic_map[i] += parse_agnostic_map[label]
177
+
178
+ # load person image
179
+ img = Image.open(osp.join(self.data_path, 'image', img_name))
180
+ img = transforms.Resize(self.load_width, interpolation=2)(img)
181
+ img_agnostic = self.get_img_agnostic(img, parse, pose_data)
182
+ img = self.transform(img)
183
+ img_agnostic = self.transform(img_agnostic) # [-1,1]
184
+
185
+ result = {
186
+ 'img_name': img_name,
187
+ 'c_name': c_name,
188
+ 'img': img,
189
+ 'img_agnostic': img_agnostic,
190
+ 'parse_agnostic': new_parse_agnostic_map,
191
+ 'pose': pose_rgb,
192
+ 'cloth': c,
193
+ 'cloth_mask': cm,
194
+ }
195
+ return result
196
+
197
+ def __len__(self):
198
+ return len(self.img_names)
199
+
200
+
201
+ class VITONDataLoader:
202
+ def __init__(self, opt, dataset):
203
+ super(VITONDataLoader, self).__init__()
204
+
205
+ if opt.shuffle:
206
+ train_sampler = data.sampler.RandomSampler(dataset)
207
+ else:
208
+ train_sampler = None
209
+
210
+ self.data_loader = data.DataLoader(
211
+ dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
212
+ num_workers=opt.workers, pin_memory=True, drop_last=True, sampler=train_sampler
213
+ )
214
+ self.dataset = dataset
215
+ self.data_iter = self.data_loader.__iter__()
216
+
217
+ def next_batch(self):
218
+ try:
219
+ batch = self.data_iter.__next__()
220
+ except StopIteration:
221
+ self.data_iter = self.data_loader.__iter__()
222
+ batch = self.data_iter.__next__()
223
+
224
+ return batch
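For orientation, here is a sketch of how `VITONDataset` and `VITONDataLoader` might be wired up. Only the attribute names on `opt` are taken from the code above; the concrete values, the 1024x768 resolution, and the on-disk VITON-HD style layout under `dataset_dir` are assumptions.

```python
# Hypothetical wiring of the dataset classes above; values are placeholders.
from types import SimpleNamespace
from datasets import VITONDataset, VITONDataLoader

opt = SimpleNamespace(
    load_height=1024, load_width=768, semantic_nc=13,   # assumed VITON-HD style resolution
    dataset_dir="./datasets", dataset_mode="test", dataset_list="test_pairs.txt",
    shuffle=False, batch_size=1, workers=0,
)

dataset = VITONDataset(opt)           # expects cloth/, cloth-mask/, image/, image-parse/, openpose-* dirs
loader = VITONDataLoader(opt, dataset)
batch = loader.next_batch()           # dict with 'img', 'img_agnostic', 'parse_agnostic', 'pose', 'cloth', 'cloth_mask'
```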
network.py ADDED
@@ -0,0 +1,526 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init
6
+ from torch.nn.utils.spectral_norm import spectral_norm
7
+
8
+
9
+ # ----------------------------------------------------------------------------------------------------------------------
10
+ # Common classes
11
+ # ----------------------------------------------------------------------------------------------------------------------
12
+ class BaseNetwork(nn.Module):
13
+ def __init__(self):
14
+ super(BaseNetwork, self).__init__()
15
+
16
+ def print_network(self):
17
+ num_params = 0
18
+ for param in self.parameters():
19
+ num_params += param.numel()
20
+ print("Network [{}] was created. Total number of parameters: {:.1f} million. "
21
+ "To see the architecture, do print(network).".format(self.__class__.__name__, num_params / 1000000))
22
+
23
+ def init_weights(self, init_type='normal', gain=0.02):
24
+ def init_func(m):
25
+ classname = m.__class__.__name__
26
+ if 'BatchNorm2d' in classname:
27
+ if hasattr(m, 'weight') and m.weight is not None:
28
+ init.normal_(m.weight.data, 1.0, gain)
29
+ if hasattr(m, 'bias') and m.bias is not None:
30
+ init.constant_(m.bias.data, 0.0)
31
+ elif ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
32
+ if init_type == 'normal':
33
+ init.normal_(m.weight.data, 0.0, gain)
34
+ elif init_type == 'xavier':
35
+ init.xavier_normal_(m.weight.data, gain=gain)
36
+ elif init_type == 'xavier_uniform':
37
+ init.xavier_uniform_(m.weight.data, gain=1.0)
38
+ elif init_type == 'kaiming':
39
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
40
+ elif init_type == 'orthogonal':
41
+ init.orthogonal_(m.weight.data, gain=gain)
42
+ elif init_type == 'none': # uses pytorch's default init method
43
+ m.reset_parameters()
44
+ else:
45
+ raise NotImplementedError("initialization method '{}' is not implemented".format(init_type))
46
+ if hasattr(m, 'bias') and m.bias is not None:
47
+ init.constant_(m.bias.data, 0.0)
48
+
49
+ self.apply(init_func)
50
+
51
+ def forward(self, *inputs):
52
+ pass
53
+
54
+
55
+ # ----------------------------------------------------------------------------------------------------------------------
56
+ # SegGenerator-related classes
57
+ # ----------------------------------------------------------------------------------------------------------------------
58
+ class SegGenerator(BaseNetwork):
59
+ def __init__(self, opt, input_nc, output_nc=13, norm_layer=nn.InstanceNorm2d):
60
+ super(SegGenerator, self).__init__()
61
+
62
+ self.conv1 = nn.Sequential(nn.Conv2d(input_nc, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(),
63
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU())
64
+
65
+ self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU(),
66
+ nn.Conv2d(128, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU())
67
+
68
+ self.conv3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU(),
69
+ nn.Conv2d(256, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU())
70
+
71
+ self.conv4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU(),
72
+ nn.Conv2d(512, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU())
73
+
74
+ self.conv5 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=3, padding=1), norm_layer(1024), nn.ReLU(),
75
+ nn.Conv2d(1024, 1024, kernel_size=3, padding=1), norm_layer(1024), nn.ReLU())
76
+
77
+ self.up6 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
78
+ nn.Conv2d(1024, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU())
79
+ self.conv6 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU(),
80
+ nn.Conv2d(512, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU())
81
+
82
+ self.up7 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
83
+ nn.Conv2d(512, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU())
84
+ self.conv7 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU(),
85
+ nn.Conv2d(256, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU())
86
+
87
+ self.up8 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
88
+ nn.Conv2d(256, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU())
89
+ self.conv8 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU(),
90
+ nn.Conv2d(128, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU())
91
+
92
+ self.up9 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
93
+ nn.Conv2d(128, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU())
94
+ self.conv9 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(),
95
+ nn.Conv2d(64, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(),
96
+ nn.Conv2d(64, output_nc, kernel_size=3, padding=1))
97
+
98
+ self.pool = nn.MaxPool2d(2)
99
+ self.drop = nn.Dropout(0.5)
100
+ self.sigmoid = nn.Sigmoid()
101
+
102
+ self.print_network()
103
+ self.init_weights(opt.init_type, opt.init_variance)
104
+
105
+ def forward(self, x):
106
+ conv1 = self.conv1(x)
107
+ conv2 = self.conv2(self.pool(conv1))
108
+ conv3 = self.conv3(self.pool(conv2))
109
+ conv4 = self.drop(self.conv4(self.pool(conv3)))
110
+ conv5 = self.drop(self.conv5(self.pool(conv4)))
111
+
112
+ conv6 = self.conv6(torch.cat((conv4, self.up6(conv5)), 1))
113
+ conv7 = self.conv7(torch.cat((conv3, self.up7(conv6)), 1))
114
+ conv8 = self.conv8(torch.cat((conv2, self.up8(conv7)), 1))
115
+ conv9 = self.conv9(torch.cat((conv1, self.up9(conv8)), 1))
116
+ return self.sigmoid(conv9)
117
+
118
+
119
+ # ----------------------------------------------------------------------------------------------------------------------
120
+ # GMM-related classes
121
+ # ----------------------------------------------------------------------------------------------------------------------
122
+ class FeatureExtraction(BaseNetwork):
123
+ def __init__(self, input_nc, ngf=64, num_layers=4, norm_layer=nn.BatchNorm2d):
124
+ super(FeatureExtraction, self).__init__()
125
+
126
+ nf = ngf
127
+ layers = [nn.Conv2d(input_nc, nf, kernel_size=4, stride=2, padding=1), nn.ReLU(), norm_layer(nf)]
128
+
129
+ for i in range(1, num_layers):
130
+ nf_prev = nf
131
+ nf = min(nf * 2, 512)
132
+ layers += [nn.Conv2d(nf_prev, nf, kernel_size=4, stride=2, padding=1), nn.ReLU(), norm_layer(nf)]
133
+
134
+ layers += [nn.Conv2d(nf, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(), norm_layer(512)]
135
+ layers += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU()]
136
+
137
+ self.model = nn.Sequential(*layers)
138
+ self.init_weights()
139
+
140
+ def forward(self, x):
141
+ return self.model(x)
142
+
143
+
144
+ class FeatureCorrelation(nn.Module):
145
+ def __init__(self):
146
+ super(FeatureCorrelation, self).__init__()
147
+
148
+ def forward(self, featureA, featureB):
149
+ # Reshape features for matrix multiplication.
150
+ b, c, h, w = featureA.size()
151
+ featureA = featureA.permute(0, 3, 2, 1).reshape(b, w * h, c)
152
+ featureB = featureB.reshape(b, c, h * w)
153
+
154
+ # Perform matrix multiplication.
155
+ corr = torch.bmm(featureA, featureB).reshape(b, w * h, h, w)
156
+ return corr
157
+
158
+
159
+ class FeatureRegression(nn.Module):
160
+ def __init__(self, input_nc=512, output_size=6, norm_layer=nn.BatchNorm2d):
161
+ super(FeatureRegression, self).__init__()
162
+
163
+ self.conv = nn.Sequential(
164
+ nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1), norm_layer(512), nn.ReLU(),
165
+ nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1), norm_layer(256), nn.ReLU(),
166
+ nn.Conv2d(256, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU(),
167
+ nn.Conv2d(128, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU()
168
+ )
169
+ self.linear = nn.Linear(64 * (input_nc // 16), output_size)
170
+ self.tanh = nn.Tanh()
171
+
172
+ def forward(self, x):
173
+ x = self.conv(x)
174
+ x = self.linear(x.reshape(x.size(0), -1))
175
+ return self.tanh(x)
176
+
177
+
178
+ class TpsGridGen(nn.Module):
179
+ def __init__(self, opt, dtype=torch.float):
180
+ super(TpsGridGen, self).__init__()
181
+
182
+ # Create a grid in numpy.
183
+ # TODO: set an appropriate interval ([-1, 1] in CP-VTON, [-0.9, 0.9] in the current version of VITON-HD)
184
+ grid_X, grid_Y = np.meshgrid(np.linspace(-0.9, 0.9, opt.load_width), np.linspace(-0.9, 0.9, opt.load_height))
185
+ grid_X = torch.tensor(grid_X, dtype=dtype).unsqueeze(0).unsqueeze(3) # size: (1, h, w, 1)
186
+ grid_Y = torch.tensor(grid_Y, dtype=dtype).unsqueeze(0).unsqueeze(3) # size: (1, h, w, 1)
187
+
188
+ # Initialize the regular grid for control points P.
189
+ self.N = opt.grid_size * opt.grid_size
190
+ coords = np.linspace(-0.9, 0.9, opt.grid_size)
191
+ # FIXME: why are P_X and P_Y swapped here?
192
+ P_Y, P_X = np.meshgrid(coords, coords)
193
+ P_X = torch.tensor(P_X, dtype=dtype).reshape(self.N, 1)
194
+ P_Y = torch.tensor(P_Y, dtype=dtype).reshape(self.N, 1)
195
+ P_X_base = P_X.clone()
196
+ P_Y_base = P_Y.clone()
197
+
198
+ Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
199
+ P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4) # size: (1, 1, 1, 1, self.N)
200
+ P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4) # size: (1, 1, 1, 1, self.N)
201
+
202
+ self.register_buffer('grid_X', grid_X, False)
203
+ self.register_buffer('grid_Y', grid_Y, False)
204
+ self.register_buffer('P_X_base', P_X_base, False)
205
+ self.register_buffer('P_Y_base', P_Y_base, False)
206
+ self.register_buffer('Li', Li, False)
207
+ self.register_buffer('P_X', P_X, False)
208
+ self.register_buffer('P_Y', P_Y, False)
209
+
210
+ # TODO: refactor
211
+ def compute_L_inverse(self,X,Y):
212
+ N = X.size()[0] # num of points (along dim 0)
213
+ # construct matrix K
214
+ Xmat = X.expand(N,N)
215
+ Ymat = Y.expand(N,N)
216
+ P_dist_squared = torch.pow(Xmat-Xmat.transpose(0,1),2)+torch.pow(Ymat-Ymat.transpose(0,1),2)
217
+ P_dist_squared[P_dist_squared==0]=1 # make diagonal 1 to avoid NaN in log computation
218
+ K = torch.mul(P_dist_squared,torch.log(P_dist_squared))
219
+ # construct matrix L
220
+ O = torch.FloatTensor(N,1).fill_(1)
221
+ Z = torch.FloatTensor(3,3).fill_(0)
222
+ P = torch.cat((O,X,Y),1)
223
+ L = torch.cat((torch.cat((K,P),1),torch.cat((P.transpose(0,1),Z),1)),0)
224
+ Li = torch.inverse(L)
225
+ return Li
226
+
227
+ # TODO: refactor
228
+ def apply_transformation(self,theta,points):
229
+ if theta.dim()==2:
230
+ theta = theta.unsqueeze(2).unsqueeze(3)
231
+ # points should be in the [B,H,W,2] format,
232
+ # where points[:,:,:,0] are the X coords
233
+ # and points[:,:,:,1] are the Y coords
234
+
235
+ # input are the corresponding control points P_i
236
+ batch_size = theta.size()[0]
237
+ # split theta into point coordinates
238
+ Q_X=theta[:,:self.N,:,:].squeeze(3)
239
+ Q_Y=theta[:,self.N:,:,:].squeeze(3)
240
+ Q_X = Q_X + self.P_X_base.expand_as(Q_X)
241
+ Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
242
+
243
+ # get spatial dimensions of points
244
+ points_b = points.size()[0]
245
+ points_h = points.size()[1]
246
+ points_w = points.size()[2]
247
+
248
+ # repeat pre-defined control points along spatial dimensions of points to be transformed
249
+ P_X = self.P_X.expand((1,points_h,points_w,1,self.N))
250
+ P_Y = self.P_Y.expand((1,points_h,points_w,1,self.N))
251
+
252
+ # compute weights for non-linear part
253
+ W_X = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_X)
254
+ W_Y = torch.bmm(self.Li[:,:self.N,:self.N].expand((batch_size,self.N,self.N)),Q_Y)
255
+ # reshape
256
+ # W_X, W_Y: size [B,H,W,1,N]
257
+ W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
258
+ W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
259
+ # compute weights for affine part
260
+ A_X = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_X)
261
+ A_Y = torch.bmm(self.Li[:,self.N:,:self.N].expand((batch_size,3,self.N)),Q_Y)
262
+ # reshape
263
+ # A_X, A_Y: size [B,H,W,1,3]
264
+ A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
265
+ A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1,4).repeat(1,points_h,points_w,1,1)
266
+
267
+ # compute distance P_i - (grid_X,grid_Y)
268
+ # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch
269
+ points_X_for_summation = points[:,:,:,0].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,0].size()+(1,self.N))
270
+ points_Y_for_summation = points[:,:,:,1].unsqueeze(3).unsqueeze(4).expand(points[:,:,:,1].size()+(1,self.N))
271
+
272
+ if points_b==1:
273
+ delta_X = points_X_for_summation-P_X
274
+ delta_Y = points_Y_for_summation-P_Y
275
+ else:
276
+ # use expanded P_X,P_Y in batch dimension
277
+ delta_X = points_X_for_summation-P_X.expand_as(points_X_for_summation)
278
+ delta_Y = points_Y_for_summation-P_Y.expand_as(points_Y_for_summation)
279
+
280
+ dist_squared = torch.pow(delta_X,2)+torch.pow(delta_Y,2)
281
+ # U: size [1,H,W,1,N]
282
+ dist_squared[dist_squared==0]=1 # avoid NaN in log computation
283
+ U = torch.mul(dist_squared,torch.log(dist_squared))
284
+
285
+ # expand grid in batch dimension if necessary
286
+ points_X_batch = points[:,:,:,0].unsqueeze(3)
287
+ points_Y_batch = points[:,:,:,1].unsqueeze(3)
288
+ if points_b==1:
289
+ points_X_batch = points_X_batch.expand((batch_size,)+points_X_batch.size()[1:])
290
+ points_Y_batch = points_Y_batch.expand((batch_size,)+points_Y_batch.size()[1:])
291
+
292
+ points_X_prime = A_X[:,:,:,:,0]+ \
293
+ torch.mul(A_X[:,:,:,:,1],points_X_batch) + \
294
+ torch.mul(A_X[:,:,:,:,2],points_Y_batch) + \
295
+ torch.sum(torch.mul(W_X,U.expand_as(W_X)),4)
296
+
297
+ points_Y_prime = A_Y[:,:,:,:,0]+ \
298
+ torch.mul(A_Y[:,:,:,:,1],points_X_batch) + \
299
+ torch.mul(A_Y[:,:,:,:,2],points_Y_batch) + \
300
+ torch.sum(torch.mul(W_Y,U.expand_as(W_Y)),4)
301
+
302
+ return torch.cat((points_X_prime,points_Y_prime),3)
303
+
304
+ def forward(self, theta):
305
+ warped_grid = self.apply_transformation(theta, torch.cat((self.grid_X, self.grid_Y), 3))
306
+ return warped_grid
307
+
308
+
309
+ class GMM(nn.Module):
310
+ def __init__(self, opt, inputA_nc, inputB_nc):
311
+ super(GMM, self).__init__()
312
+
313
+ self.extractionA = FeatureExtraction(inputA_nc, ngf=64, num_layers=4)
314
+ self.extractionB = FeatureExtraction(inputB_nc, ngf=64, num_layers=4)
315
+ self.correlation = FeatureCorrelation()
316
+ self.regression = FeatureRegression(input_nc=(opt.load_width // 64) * (opt.load_height // 64),
317
+ output_size=2 * opt.grid_size**2)
318
+ self.gridGen = TpsGridGen(opt)
319
+
320
+ def forward(self, inputA, inputB):
321
+ featureA = F.normalize(self.extractionA(inputA), dim=1)
322
+ featureB = F.normalize(self.extractionB(inputB), dim=1)
323
+ corr = self.correlation(featureA, featureB)
324
+ theta = self.regression(corr)
325
+
326
+ warped_grid = self.gridGen(theta)
327
+ return theta, warped_grid
328
+
329
+
330
+ # ----------------------------------------------------------------------------------------------------------------------
331
+ # ALIASGenerator-related classes
332
+ # ----------------------------------------------------------------------------------------------------------------------
333
+ class MaskNorm(nn.Module):
334
+ def __init__(self, norm_nc):
335
+ super(MaskNorm, self).__init__()
336
+
337
+ self.norm_layer = nn.InstanceNorm2d(norm_nc, affine=False)
338
+
339
+ def normalize_region(self, region, mask):
340
+ b, c, h, w = region.size()
341
+
342
+ num_pixels = mask.sum((2, 3), keepdim=True) # size: (b, 1, 1, 1)
343
+ num_pixels[num_pixels == 0] = 1
344
+ mu = region.sum((2, 3), keepdim=True) / num_pixels # size: (b, c, 1, 1)
345
+
346
+ normalized_region = self.norm_layer(region + (1 - mask) * mu)
347
+ return normalized_region * torch.sqrt(num_pixels / (h * w))
348
+
349
+ def forward(self, x, mask):
350
+ mask = mask.detach()
351
+ normalized_foreground = self.normalize_region(x * mask, mask)
352
+ normalized_background = self.normalize_region(x * (1 - mask), 1 - mask)
353
+ return normalized_foreground + normalized_background
354
+
355
+
356
+ class ALIASNorm(nn.Module):
357
+ def __init__(self, norm_type, norm_nc, label_nc):
358
+ super(ALIASNorm, self).__init__()
359
+
360
+ self.noise_scale = nn.Parameter(torch.zeros(norm_nc))
361
+
362
+ assert norm_type.startswith('alias')
363
+ param_free_norm_type = norm_type[len('alias'):]
364
+ if param_free_norm_type == 'batch':
365
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
366
+ elif param_free_norm_type == 'instance':
367
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
368
+ elif param_free_norm_type == 'mask':
369
+ self.param_free_norm = MaskNorm(norm_nc)
370
+ else:
371
+ raise ValueError(
372
+ "'{}' is not a recognized parameter-free normalization type in ALIASNorm".format(param_free_norm_type)
373
+ )
374
+
375
+ nhidden = 128
376
+ ks = 3
377
+ pw = ks // 2
378
+ self.conv_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
379
+ self.conv_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
380
+ self.conv_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
381
+
382
+ def forward(self, x, seg, misalign_mask=None):
383
+ # Part 1. Generate parameter-free normalized activations.
384
+ b, c, h, w = x.size()
385
+ noise = (torch.randn(b, w, h, 1, device=x.device) * self.noise_scale).transpose(1, 3)  # keep the noise on the same device as the input
386
+
387
+ if misalign_mask is None:
388
+ normalized = self.param_free_norm(x + noise)
389
+ else:
390
+ normalized = self.param_free_norm(x + noise, misalign_mask)
391
+
392
+ # Part 2. Produce affine parameters conditioned on the segmentation map.
393
+ actv = self.conv_shared(seg)
394
+ gamma = self.conv_gamma(actv)
395
+ beta = self.conv_beta(actv)
396
+
397
+ # Apply the affine parameters.
398
+ output = normalized * (1 + gamma) + beta
399
+ return output
400
+
401
+
402
+ class ALIASResBlock(nn.Module):
403
+ def __init__(self, opt, input_nc, output_nc, use_mask_norm=True):
404
+ super(ALIASResBlock, self).__init__()
405
+
406
+ self.learned_shortcut = (input_nc != output_nc)
407
+ middle_nc = min(input_nc, output_nc)
408
+
409
+ self.conv_0 = nn.Conv2d(input_nc, middle_nc, kernel_size=3, padding=1)
410
+ self.conv_1 = nn.Conv2d(middle_nc, output_nc, kernel_size=3, padding=1)
411
+ if self.learned_shortcut:
412
+ self.conv_s = nn.Conv2d(input_nc, output_nc, kernel_size=1, bias=False)
413
+
414
+ subnorm_type = opt.norm_G
415
+ if subnorm_type.startswith('spectral'):
416
+ subnorm_type = subnorm_type[len('spectral'):]
417
+ self.conv_0 = spectral_norm(self.conv_0)
418
+ self.conv_1 = spectral_norm(self.conv_1)
419
+ if self.learned_shortcut:
420
+ self.conv_s = spectral_norm(self.conv_s)
421
+
422
+ semantic_nc = opt.semantic_nc
423
+ if use_mask_norm:
424
+ subnorm_type = 'aliasmask'
425
+ semantic_nc = semantic_nc + 1
426
+
427
+ self.norm_0 = ALIASNorm(subnorm_type, input_nc, semantic_nc)
428
+ self.norm_1 = ALIASNorm(subnorm_type, middle_nc, semantic_nc)
429
+ if self.learned_shortcut:
430
+ self.norm_s = ALIASNorm(subnorm_type, input_nc, semantic_nc)
431
+
432
+ self.relu = nn.LeakyReLU(0.2)
433
+
434
+ def shortcut(self, x, seg, misalign_mask):
435
+ if self.learned_shortcut:
436
+ return self.conv_s(self.norm_s(x, seg, misalign_mask))
437
+ else:
438
+ return x
439
+
440
+ def forward(self, x, seg, misalign_mask=None):
441
+ seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
442
+ if misalign_mask is not None:
443
+ misalign_mask = F.interpolate(misalign_mask, size=x.size()[2:], mode='nearest')
444
+
445
+ x_s = self.shortcut(x, seg, misalign_mask)
446
+
447
+ dx = self.conv_0(self.relu(self.norm_0(x, seg, misalign_mask)))
448
+ dx = self.conv_1(self.relu(self.norm_1(dx, seg, misalign_mask)))
449
+ output = x_s + dx
450
+ return output
451
+
452
+
453
+ class ALIASGenerator(BaseNetwork):
454
+ def __init__(self, opt, input_nc):
455
+ super(ALIASGenerator, self).__init__()
456
+ self.num_upsampling_layers = opt.num_upsampling_layers
457
+
458
+ self.sh, self.sw = self.compute_latent_vector_size(opt)
459
+
460
+ nf = opt.ngf
461
+ self.conv_0 = nn.Conv2d(input_nc, nf * 16, kernel_size=3, padding=1)
462
+ for i in range(1, 8):
463
+ self.add_module('conv_{}'.format(i), nn.Conv2d(input_nc, 16, kernel_size=3, padding=1))
464
+
465
+ self.head_0 = ALIASResBlock(opt, nf * 16, nf * 16)
466
+
467
+ self.G_middle_0 = ALIASResBlock(opt, nf * 16 + 16, nf * 16)
468
+ self.G_middle_1 = ALIASResBlock(opt, nf * 16 + 16, nf * 16)
469
+
470
+ self.up_0 = ALIASResBlock(opt, nf * 16 + 16, nf * 8)
471
+ self.up_1 = ALIASResBlock(opt, nf * 8 + 16, nf * 4)
472
+ self.up_2 = ALIASResBlock(opt, nf * 4 + 16, nf * 2, use_mask_norm=False)
473
+ self.up_3 = ALIASResBlock(opt, nf * 2 + 16, nf * 1, use_mask_norm=False)
474
+ if self.num_upsampling_layers == 'most':
475
+ self.up_4 = ALIASResBlock(opt, nf * 1 + 16, nf // 2, use_mask_norm=False)
476
+ nf = nf // 2
477
+
478
+ self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
479
+
480
+ self.up = nn.Upsample(scale_factor=2, mode='nearest')
481
+ self.relu = nn.LeakyReLU(0.2)
482
+ self.tanh = nn.Tanh()
483
+
484
+ self.print_network()
485
+ self.init_weights(opt.init_type, opt.init_variance)
486
+
487
+ def compute_latent_vector_size(self, opt):
488
+ if self.num_upsampling_layers == 'normal':
489
+ num_up_layers = 5
490
+ elif self.num_upsampling_layers == 'more':
491
+ num_up_layers = 6
492
+ elif self.num_upsampling_layers == 'most':
493
+ num_up_layers = 7
494
+ else:
495
+ raise ValueError("opt.num_upsampling_layers '{}' is not recognized".format(self.num_upsampling_layers))
496
+
497
+ sh = opt.load_height // 2**num_up_layers
498
+ sw = opt.load_width // 2**num_up_layers
499
+ return sh, sw
500
+
501
+ def forward(self, x, seg, seg_div, misalign_mask):
502
+ samples = [F.interpolate(x, size=(self.sh * 2**i, self.sw * 2**i), mode='nearest') for i in range(8)]
503
+ features = [self._modules['conv_{}'.format(i)](samples[i]) for i in range(8)]
504
+
505
+ x = self.head_0(features[0], seg_div, misalign_mask)
506
+
507
+ x = self.up(x)
508
+ x = self.G_middle_0(torch.cat((x, features[1]), 1), seg_div, misalign_mask)
509
+ if self.num_upsampling_layers in ['more', 'most']:
510
+ x = self.up(x)
511
+ x = self.G_middle_1(torch.cat((x, features[2]), 1), seg_div, misalign_mask)
512
+
513
+ x = self.up(x)
514
+ x = self.up_0(torch.cat((x, features[3]), 1), seg_div, misalign_mask)
515
+ x = self.up(x)
516
+ x = self.up_1(torch.cat((x, features[4]), 1), seg_div, misalign_mask)
517
+ x = self.up(x)
518
+ x = self.up_2(torch.cat((x, features[5]), 1), seg)
519
+ x = self.up(x)
520
+ x = self.up_3(torch.cat((x, features[6]), 1), seg)
521
+ if self.num_upsampling_layers == 'most':
522
+ x = self.up(x)
523
+ x = self.up_4(torch.cat((x, features[7]), 1), seg)
524
+
525
+ x = self.conv_img(self.relu(x))
526
+ return self.tanh(x)
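
A minimal smoke-test sketch of how this generator is driven at inference time, mirroring the call made in test.py later in this commit. The SimpleNamespace options are assumptions copied from test.py's defaults, all tensors are random placeholders, and the import assumes this file is importable as network from the repository root.

    import torch
    from types import SimpleNamespace
    from network import ALIASGenerator

    opt = SimpleNamespace(ngf=64, num_upsampling_layers='most', norm_G='spectralaliasinstance',
                          semantic_nc=7, load_height=1024, load_width=768,
                          init_type='xavier', init_variance=0.02)
    alias = ALIASGenerator(opt, input_nc=9).eval()

    img_agnostic = torch.randn(1, 3, 1024, 768)   # person image with the clothing region masked out
    pose = torch.randn(1, 3, 1024, 768)           # rendered pose map
    warped_c = torch.randn(1, 3, 1024, 768)       # cloth after GMM warping
    parse = torch.randn(1, 7, 1024, 768)          # 7-channel parse map
    misalign = torch.zeros(1, 1, 1024, 768)       # misalignment mask
    parse_div = torch.cat((parse, misalign), dim=1)

    with torch.no_grad():
        out = alias(torch.cat((img_agnostic, pose, warped_c), dim=1), parse, parse_div, misalign)
    print(out.shape)   # torch.Size([1, 3, 1024, 768])
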
networks/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .u2net import U2NET
networks/u2net.py ADDED
@@ -0,0 +1,565 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class REBNCONV(nn.Module):
7
+ def __init__(self, in_ch=3, out_ch=3, dirate=1):
8
+ super(REBNCONV, self).__init__()
9
+
10
+ self.conv_s1 = nn.Conv2d(
11
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
12
+ )
13
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
14
+ self.relu_s1 = nn.ReLU(inplace=True)
15
+
16
+ def forward(self, x):
17
+
18
+ hx = x
19
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
20
+
21
+ return xout
22
+
23
+
24
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
25
+ def _upsample_like(src, tar):
26
+
27
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
28
+
29
+ return src
30
+
31
+
32
+ ### RSU-7 ###
33
+ class RSU7(nn.Module): # UNet07DRES(nn.Module):
34
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
35
+ super(RSU7, self).__init__()
36
+
37
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
38
+
39
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
40
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
41
+
42
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
43
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
44
+
45
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
46
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
47
+
48
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
49
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
50
+
51
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
52
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
53
+
54
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
55
+
56
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
57
+
58
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
59
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
60
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
61
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
62
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
63
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
64
+
65
+ def forward(self, x):
66
+
67
+ hx = x
68
+ hxin = self.rebnconvin(hx)
69
+
70
+ hx1 = self.rebnconv1(hxin)
71
+ hx = self.pool1(hx1)
72
+
73
+ hx2 = self.rebnconv2(hx)
74
+ hx = self.pool2(hx2)
75
+
76
+ hx3 = self.rebnconv3(hx)
77
+ hx = self.pool3(hx3)
78
+
79
+ hx4 = self.rebnconv4(hx)
80
+ hx = self.pool4(hx4)
81
+
82
+ hx5 = self.rebnconv5(hx)
83
+ hx = self.pool5(hx5)
84
+
85
+ hx6 = self.rebnconv6(hx)
86
+
87
+ hx7 = self.rebnconv7(hx6)
88
+
89
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
90
+ hx6dup = _upsample_like(hx6d, hx5)
91
+
92
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
93
+ hx5dup = _upsample_like(hx5d, hx4)
94
+
95
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
96
+ hx4dup = _upsample_like(hx4d, hx3)
97
+
98
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
99
+ hx3dup = _upsample_like(hx3d, hx2)
100
+
101
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
102
+ hx2dup = _upsample_like(hx2d, hx1)
103
+
104
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
105
+
106
+ """
107
+ del hx1, hx2, hx3, hx4, hx5, hx6, hx7
108
+ del hx6d, hx5d, hx3d, hx2d
109
+ del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
110
+ """
111
+
112
+ return hx1d + hxin
113
+
114
+
115
+ ### RSU-6 ###
116
+ class RSU6(nn.Module): # UNet06DRES(nn.Module):
117
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
118
+ super(RSU6, self).__init__()
119
+
120
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
121
+
122
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
123
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
124
+
125
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
126
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+
136
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
137
+
138
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
139
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
140
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
141
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
143
+
144
+ def forward(self, x):
145
+
146
+ hx = x
147
+
148
+ hxin = self.rebnconvin(hx)
149
+
150
+ hx1 = self.rebnconv1(hxin)
151
+ hx = self.pool1(hx1)
152
+
153
+ hx2 = self.rebnconv2(hx)
154
+ hx = self.pool2(hx2)
155
+
156
+ hx3 = self.rebnconv3(hx)
157
+ hx = self.pool3(hx3)
158
+
159
+ hx4 = self.rebnconv4(hx)
160
+ hx = self.pool4(hx4)
161
+
162
+ hx5 = self.rebnconv5(hx)
163
+
164
+ hx6 = self.rebnconv6(hx5)
165
+
166
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
167
+ hx5dup = _upsample_like(hx5d, hx4)
168
+
169
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
170
+ hx4dup = _upsample_like(hx4d, hx3)
171
+
172
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
173
+ hx3dup = _upsample_like(hx3d, hx2)
174
+
175
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
176
+ hx2dup = _upsample_like(hx2d, hx1)
177
+
178
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
179
+
180
+ """
181
+ del hx1, hx2, hx3, hx4, hx5, hx6
182
+ del hx5d, hx4d, hx3d, hx2d
183
+ del hx2dup, hx3dup, hx4dup, hx5dup
184
+ """
185
+
186
+ return hx1d + hxin
187
+
188
+
189
+ ### RSU-5 ###
190
+ class RSU5(nn.Module): # UNet05DRES(nn.Module):
191
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
192
+ super(RSU5, self).__init__()
193
+
194
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
195
+
196
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
197
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
198
+
199
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
200
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
201
+
202
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
203
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
204
+
205
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
206
+
207
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
208
+
209
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
211
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
212
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
213
+
214
+ def forward(self, x):
215
+
216
+ hx = x
217
+
218
+ hxin = self.rebnconvin(hx)
219
+
220
+ hx1 = self.rebnconv1(hxin)
221
+ hx = self.pool1(hx1)
222
+
223
+ hx2 = self.rebnconv2(hx)
224
+ hx = self.pool2(hx2)
225
+
226
+ hx3 = self.rebnconv3(hx)
227
+ hx = self.pool3(hx3)
228
+
229
+ hx4 = self.rebnconv4(hx)
230
+
231
+ hx5 = self.rebnconv5(hx4)
232
+
233
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
234
+ hx4dup = _upsample_like(hx4d, hx3)
235
+
236
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
237
+ hx3dup = _upsample_like(hx3d, hx2)
238
+
239
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
240
+ hx2dup = _upsample_like(hx2d, hx1)
241
+
242
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
243
+
244
+ """
245
+ del hx1, hx2, hx3, hx4, hx5
246
+ del hx4d, hx3d, hx2d
247
+ del hx2dup, hx3dup, hx4dup
248
+ """
249
+
250
+ return hx1d + hxin
251
+
252
+
253
+ ### RSU-4 ###
254
+ class RSU4(nn.Module): # UNet04DRES(nn.Module):
255
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
256
+ super(RSU4, self).__init__()
257
+
258
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
259
+
260
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
261
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
262
+
263
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
264
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
265
+
266
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
267
+
268
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
269
+
270
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
271
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
272
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
273
+
274
+ def forward(self, x):
275
+
276
+ hx = x
277
+
278
+ hxin = self.rebnconvin(hx)
279
+
280
+ hx1 = self.rebnconv1(hxin)
281
+ hx = self.pool1(hx1)
282
+
283
+ hx2 = self.rebnconv2(hx)
284
+ hx = self.pool2(hx2)
285
+
286
+ hx3 = self.rebnconv3(hx)
287
+
288
+ hx4 = self.rebnconv4(hx3)
289
+
290
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
291
+ hx3dup = _upsample_like(hx3d, hx2)
292
+
293
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
294
+ hx2dup = _upsample_like(hx2d, hx1)
295
+
296
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
297
+
298
+ """
299
+ del hx1, hx2, hx3, hx4
300
+ del hx3d, hx2d
301
+ del hx2dup, hx3dup
302
+ """
303
+
304
+ return hx1d + hxin
305
+
306
+
307
+ ### RSU-4F ###
308
+ class RSU4F(nn.Module): # UNet04FRES(nn.Module):
309
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
310
+ super(RSU4F, self).__init__()
311
+
312
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
313
+
314
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
315
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
316
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
317
+
318
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
319
+
320
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
321
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
322
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
323
+
324
+ def forward(self, x):
325
+
326
+ hx = x
327
+
328
+ hxin = self.rebnconvin(hx)
329
+
330
+ hx1 = self.rebnconv1(hxin)
331
+ hx2 = self.rebnconv2(hx1)
332
+ hx3 = self.rebnconv3(hx2)
333
+
334
+ hx4 = self.rebnconv4(hx3)
335
+
336
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
337
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
338
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
339
+
340
+ """
341
+ del hx1, hx2, hx3, hx4
342
+ del hx3d, hx2d
343
+ """
344
+
345
+ return hx1d + hxin
346
+
347
+
348
+ ##### U^2-Net ####
349
+ class U2NET(nn.Module):
350
+ def __init__(self, in_ch=3, out_ch=1):
351
+ super(U2NET, self).__init__()
352
+
353
+ self.stage1 = RSU7(in_ch, 32, 64)
354
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
355
+
356
+ self.stage2 = RSU6(64, 32, 128)
357
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
358
+
359
+ self.stage3 = RSU5(128, 64, 256)
360
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage4 = RSU4(256, 128, 512)
363
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage5 = RSU4F(512, 256, 512)
366
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage6 = RSU4F(512, 256, 512)
369
+
370
+ # decoder
371
+ self.stage5d = RSU4F(1024, 256, 512)
372
+ self.stage4d = RSU4(1024, 128, 256)
373
+ self.stage3d = RSU5(512, 64, 128)
374
+ self.stage2d = RSU6(256, 32, 64)
375
+ self.stage1d = RSU7(128, 16, 64)
376
+
377
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
378
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
379
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
380
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
381
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
382
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
383
+
384
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
385
+
386
+ def forward(self, x):
387
+
388
+ hx = x
389
+
390
+ # stage 1
391
+ hx1 = self.stage1(hx)
392
+ hx = self.pool12(hx1)
393
+
394
+ # stage 2
395
+ hx2 = self.stage2(hx)
396
+ hx = self.pool23(hx2)
397
+
398
+ # stage 3
399
+ hx3 = self.stage3(hx)
400
+ hx = self.pool34(hx3)
401
+
402
+ # stage 4
403
+ hx4 = self.stage4(hx)
404
+ hx = self.pool45(hx4)
405
+
406
+ # stage 5
407
+ hx5 = self.stage5(hx)
408
+ hx = self.pool56(hx5)
409
+
410
+ # stage 6
411
+ hx6 = self.stage6(hx)
412
+ hx6up = _upsample_like(hx6, hx5)
413
+
414
+ # -------------------- decoder --------------------
415
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
416
+ hx5dup = _upsample_like(hx5d, hx4)
417
+
418
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
419
+ hx4dup = _upsample_like(hx4d, hx3)
420
+
421
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
422
+ hx3dup = _upsample_like(hx3d, hx2)
423
+
424
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
425
+ hx2dup = _upsample_like(hx2d, hx1)
426
+
427
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
428
+
429
+ # side output
430
+ d1 = self.side1(hx1d)
431
+
432
+ d2 = self.side2(hx2d)
433
+ d2 = _upsample_like(d2, d1)
434
+
435
+ d3 = self.side3(hx3d)
436
+ d3 = _upsample_like(d3, d1)
437
+
438
+ d4 = self.side4(hx4d)
439
+ d4 = _upsample_like(d4, d1)
440
+
441
+ d5 = self.side5(hx5d)
442
+ d5 = _upsample_like(d5, d1)
443
+
444
+ d6 = self.side6(hx6)
445
+ d6 = _upsample_like(d6, d1)
446
+
447
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
448
+
449
+ """
450
+ del hx1, hx2, hx3, hx4, hx5, hx6
451
+ del hx5d, hx4d, hx3d, hx2d, hx1d
452
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
453
+ """
454
+
455
+ return d0, d1, d2, d3, d4, d5, d6
456
+
457
+
458
+ ### U^2-Net small ###
459
+ class U2NETP(nn.Module):
460
+ def __init__(self, in_ch=3, out_ch=1):
461
+ super(U2NETP, self).__init__()
462
+
463
+ self.stage1 = RSU7(in_ch, 16, 64)
464
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
465
+
466
+ self.stage2 = RSU6(64, 16, 64)
467
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
468
+
469
+ self.stage3 = RSU5(64, 16, 64)
470
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
471
+
472
+ self.stage4 = RSU4(64, 16, 64)
473
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
474
+
475
+ self.stage5 = RSU4F(64, 16, 64)
476
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
477
+
478
+ self.stage6 = RSU4F(64, 16, 64)
479
+
480
+ # decoder
481
+ self.stage5d = RSU4F(128, 16, 64)
482
+ self.stage4d = RSU4(128, 16, 64)
483
+ self.stage3d = RSU5(128, 16, 64)
484
+ self.stage2d = RSU6(128, 16, 64)
485
+ self.stage1d = RSU7(128, 16, 64)
486
+
487
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
488
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
489
+ self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
490
+ self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
491
+ self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
492
+ self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
493
+
494
+ self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
495
+
496
+ def forward(self, x):
497
+
498
+ hx = x
499
+
500
+ # stage 1
501
+ hx1 = self.stage1(hx)
502
+ hx = self.pool12(hx1)
503
+
504
+ # stage 2
505
+ hx2 = self.stage2(hx)
506
+ hx = self.pool23(hx2)
507
+
508
+ # stage 3
509
+ hx3 = self.stage3(hx)
510
+ hx = self.pool34(hx3)
511
+
512
+ # stage 4
513
+ hx4 = self.stage4(hx)
514
+ hx = self.pool45(hx4)
515
+
516
+ # stage 5
517
+ hx5 = self.stage5(hx)
518
+ hx = self.pool56(hx5)
519
+
520
+ # stage 6
521
+ hx6 = self.stage6(hx)
522
+ hx6up = _upsample_like(hx6, hx5)
523
+
524
+ # decoder
525
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
526
+ hx5dup = _upsample_like(hx5d, hx4)
527
+
528
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
529
+ hx4dup = _upsample_like(hx4d, hx3)
530
+
531
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
532
+ hx3dup = _upsample_like(hx3d, hx2)
533
+
534
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
535
+ hx2dup = _upsample_like(hx2d, hx1)
536
+
537
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
538
+
539
+ # side output
540
+ d1 = self.side1(hx1d)
541
+
542
+ d2 = self.side2(hx2d)
543
+ d2 = _upsample_like(d2, d1)
544
+
545
+ d3 = self.side3(hx3d)
546
+ d3 = _upsample_like(d3, d1)
547
+
548
+ d4 = self.side4(hx4d)
549
+ d4 = _upsample_like(d4, d1)
550
+
551
+ d5 = self.side5(hx5d)
552
+ d5 = _upsample_like(d5, d1)
553
+
554
+ d6 = self.side6(hx6)
555
+ d6 = _upsample_like(d6, d1)
556
+
557
+ d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
558
+
559
+ """
560
+ del hx1, hx2, hx3, hx4, hx5, hx6
561
+ del hx5d, hx4d, hx3d, hx2d, hx1d
562
+ del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
563
+ """
564
+
565
+ return d0, d1, d2, d3, d4, d5, d6
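
A minimal forward-pass sketch for the U2NET defined here (networks/__init__.py above re-exports it). The input resolution and the 0.5 threshold are illustrative assumptions rather than values prescribed by this repository, and the import assumes the sketch runs from the repository root.

    import torch
    from networks import U2NET

    net = U2NET(in_ch=3, out_ch=1).eval()
    x = torch.randn(1, 3, 320, 320)               # dummy RGB batch, NCHW
    with torch.no_grad():
        d0, d1, d2, d3, d4, d5, d6 = net(x)       # fused map d0 plus six side outputs
    mask = (torch.sigmoid(d0) > 0.5).float()      # illustrative binarization of the fused output
    print(d0.shape)                               # torch.Size([1, 1, 320, 320])
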
remove_bg.py ADDED
@@ -0,0 +1,59 @@
1
+ import requests
2
+ import os
3
+ from PIL import Image
4
+ import numpy as np
5
+ from rembg import remove
6
+
7
+
8
+ class PreprocessInput:
9
+
10
+ def __init__(self):
11
+ self.o_width = None
12
+ self.o_height = None
13
+ self.o_image = None
14
+
15
+ self.t_width = None
16
+ self.t_height = None
17
+ self.t_image = None
18
+ self.save_path = None
19
+
20
+ def remove_bg(self, file_path: str):
21
+ self.save_path = file_path[:-3] + 'png'  # file_path[:-3] keeps the trailing dot of '.jpg', so append the bare extension
22
+ pic = Image.open(file_path)
23
+ self.o_width = np.asarray(pic).shape[1]
24
+ self.o_height = np.asarray(pic).shape[0]
25
+ try:
26
+ self.o_channels = np.asarray(pic).shape[2]
27
+ except Exception as e:
28
+ print("Single channel image and error", e)
29
+ os.remove(file_path)
30
+ self.o_image = remove(pic)
31
+ self.o_image.save(self.save_path)
32
+ os.remove(self.save_path)
33
+ return np.asarray(self.o_image)
34
+
35
+ def transform(self, width=768, height=1024):
36
+ newsize = (width, height)
37
+ self.t_height = height
38
+ self.t_width = width
39
+
40
+ pic = self.o_image
41
+ img = pic.resize(newsize)
42
+
43
+ self.t_image = img
44
+
45
+ background = Image.new("RGBA", newsize, (255, 255, 255, 255))
46
+ background.paste(img, mask=img.split()[3]) # 3 is the alpha channel
47
+ self.save_path = self.save_path[:-3] + 'jpg'  # save_path still ends with '.' after slicing off 'png'
48
+ background.convert('RGB').save(self.save_path, 'JPEG')
49
+
50
+ return np.asarray(background.convert('RGB'))
51
+
52
+
53
+ # USAGE OF THE CLASS
54
+ preprocess = PreprocessInput()
55
+ for images in os.listdir('/content/inputs/test/image'):
56
+ print(images)
57
+ if images[-3:] == 'jpg':
58
+ op = preprocess.remove_bg(r'/content/inputs/test/image/'+images)
59
+ arr = preprocess.transform(768, 1024)
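
The key step in transform() above is pasting the RGBA cutout onto a white canvas through its alpha channel. A standalone sketch of that compositing, with placeholder file names:

    from PIL import Image

    cutout = Image.open("person_no_bg.png").convert("RGBA")        # placeholder: an image produced by rembg
    canvas = Image.new("RGBA", cutout.size, (255, 255, 255, 255))  # opaque white background
    canvas.paste(cutout, mask=cutout.split()[3])                   # band 3 is the alpha channel
    canvas.convert("RGB").save("person_white_bg.jpg", "JPEG")
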
run.py ADDED
@@ -0,0 +1,41 @@
1
+ from PIL import Image
2
+ import os
3
+
4
+
5
+ # running the preprocessing
6
+
7
+ def resize_img(path):
8
+ im = Image.open(path)
9
+ im = im.resize((768, 1024))
10
+ im.save(path)
11
+
12
+
13
+ for path in os.listdir('/content/inputs/test/cloth/'):
14
+ resize_img(f'/content/inputs/test/cloth/{path}')
15
+
16
+ os.chdir('/content/clothes-virtual-try-on')
17
+ os.system("rm -rf /content/inputs/test/cloth/.ipynb_checkpoints")
18
+ os.system("python cloth-mask.py")
19
+ os.chdir('/content')
20
+ os.system("python /content/clothes-virtual-try-on/remove_bg.py")
21
+ os.system(
22
+ "python3 /content/Self-Correction-Human-Parsing/simple_extractor.py --dataset 'lip' --model-restore '/content/Self-Correction-Human-Parsing/checkpoints/final.pth' --input-dir '/content/inputs/test/image' --output-dir '/content/inputs/test/image-parse'")
23
+ os.chdir('/content')
24
+ os.system(
25
+ "cd openpose && ./build/examples/openpose/openpose.bin --image_dir /content/inputs/test/image/ --write_json /content/inputs/test/openpose-json/ --display 0 --render_pose 0 --hand")
26
+ os.system(
27
+ "cd openpose && ./build/examples/openpose/openpose.bin --image_dir /content/inputs/test/image/ --display 0 --write_images /content/inputs/test/openpose-img/ --hand --render_pose 1 --disable_blending true")
28
+
29
+ model_image = os.listdir('/content/inputs/test/image')
30
+ cloth_image = os.listdir('/content/inputs/test/cloth')
31
+ pairs = zip(model_image, cloth_image)
32
+
33
+ with open('/content/inputs/test_pairs.txt', 'w') as file:
34
+ for model, cloth in pairs:
35
+ file.write(f"{model} {cloth}\n")  # one "model cloth" pair per line
36
+
37
+ # making predictions
38
+ os.system(
39
+ "python /content/clothes-virtual-try-on/test.py --name output --dataset_dir /content/inputs --checkpoint_dir /content/clothes-virtual-try-on/checkpoints --save_dir /content/")
40
+ os.system("rm -rf /content/inputs")
41
+ os.system("rm -rf /content/output/.ipynb_checkpoints")
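
For reference, a minimal sketch of the pairs file that test.py's --dataset_list option points at: one "model cloth" pair per line, which is why the write above ends each pair with a newline. The image names are placeholders.

    pairs = [("model_01.jpg", "cloth_01.jpg"),
             ("model_02.jpg", "cloth_02.jpg")]
    with open("/content/inputs/test_pairs.txt", "w") as f:
        for model, cloth in pairs:
            f.write(f"{model} {cloth}\n")
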
setup_gradio.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
setup_ngrok.ipynb ADDED
@@ -0,0 +1,643 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "8gqt11Y_RYxU"
7
+ },
8
+ "source": [
9
+ "# Setting up the environment. PLEASE WAIT 🙃"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "RHmGnnTZL6os"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "!pip install --upgrade --no-cache-dir gdown\n",
21
+ "!pip install rembg[gpu]"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "uX3LsFFPKSwo"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "! wget -c \"https://github.com/Kitware/CMake/releases/download/v3.19.6/cmake-3.19.6.tar.gz\"\n",
33
+ "! tar xf cmake-3.19.6.tar.gz\n",
34
+ "! cd cmake-3.19.6 && ./configure && make && sudo make install"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {
41
+ "id": "51QJAhPOK9cK"
42
+ },
43
+ "outputs": [],
44
+ "source": [
45
+ "# Install library\n",
46
+ "! sudo apt-get --assume-yes update\n",
47
+ "! sudo apt-get --assume-yes install build-essential\n",
48
+ "# OpenCV\n",
49
+ "! sudo apt-get --assume-yes install libopencv-dev\n",
50
+ "# General dependencies\n",
51
+ "! sudo apt-get --assume-yes install libatlas-base-dev libprotobuf-dev libleveldb-dev libsnappy-dev libhdf5-serial-dev protobuf-compiler\n",
52
+ "! sudo apt-get --assume-yes install --no-install-recommends libboost-all-dev\n",
53
+ "# Remaining dependencies, 14.04\n",
54
+ "! sudo apt-get --assume-yes install libgflags-dev libgoogle-glog-dev liblmdb-dev\n",
55
+ "# Python3 libs\n",
56
+ "! sudo apt-get --assume-yes install python3-setuptools python3-dev build-essential\n",
57
+ "! sudo apt-get --assume-yes install python3-pip\n",
58
+ "! sudo -H pip3 install --upgrade numpy protobuf opencv-python\n",
59
+ "# OpenCL Generic\n",
60
+ "! sudo apt-get --assume-yes install opencl-headers ocl-icd-opencl-dev\n",
61
+ "! sudo apt-get --assume-yes install libviennacl-dev"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 3,
67
+ "metadata": {
68
+ "colab": {
69
+ "base_uri": "https://localhost:8080/"
70
+ },
71
+ "id": "x7tqqDLHLNDr",
72
+ "outputId": "dc52fffb-2375-4283-da82-f6327a4d73ad"
73
+ },
74
+ "outputs": [
75
+ {
76
+ "name": "stdout",
77
+ "output_type": "stream",
78
+ "text": [
79
+ "v1.7.0\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "ver_openpose = \"v1.7.0\"\n",
85
+ "! echo $ver_openpose"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {
92
+ "id": "b11_hkSgLDl5"
93
+ },
94
+ "outputs": [],
95
+ "source": [
96
+ "! git clone --depth 1 -b \"$ver_openpose\" https://github.com/CMU-Perceptual-Computing-Lab/openpose.git"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "id": "s-XyxfV8Q-DE"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "# manually downloading openpose models\n",
108
+ "%%bash\n",
109
+ "gdown 1QCSxJZpnWvM00hx49CJ2zky7PWGzpcEh\n",
110
+ "unzip models.zip\n",
111
+ "mv /content/models/face/pose_iter_116000.caffemodel /content/openpose/models/face/pose_iter_116000.caffemodel\n",
112
+ "mv /content/models/hand/pose_iter_102000.caffemodel /content/openpose/models/hand/pose_iter_102000.caffemodel\n",
113
+ "mv /content/models/pose/body_25/pose_iter_584000.caffemodel /content/openpose/models/pose/body_25/pose_iter_584000.caffemodel\n",
114
+ "mv /content/models/pose/coco/pose_iter_440000.caffemodel /content/openpose/models/pose/coco/pose_iter_440000.caffemodel\n",
115
+ "mv /content/models/pose/mpi/pose_iter_160000.caffemodel /content/openpose/models/pose/mpi/pose_iter_160000.caffemodel\n",
116
+ "rm -rf models\n",
117
+ "rm models.zip"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": 6,
123
+ "metadata": {
124
+ "id": "Bs-zIObzQLYj"
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "! cd openpose && mkdir build && cd build"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "metadata": {
135
+ "id": "7i7oHh2vQqHv"
136
+ },
137
+ "outputs": [],
138
+ "source": [
139
+ "! cd openpose/build && cmake -DUSE_CUDNN=OFF -DBUILD_PYTHON=ON .."
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "metadata": {
146
+ "id": "iBvxsDM-EYJk"
147
+ },
148
+ "outputs": [],
149
+ "source": [
150
+ "# ! cd openpose/build && cmake .."
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {
157
+ "id": "XEAY8VW0RzD0"
158
+ },
159
+ "outputs": [],
160
+ "source": [
161
+ "! cd openpose/build && make -j`nproc`\n",
162
+ "! cd openpose && mkdir output"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 9,
168
+ "metadata": {
169
+ "colab": {
170
+ "base_uri": "https://localhost:8080/"
171
+ },
172
+ "id": "60nEQBKefg3f",
173
+ "outputId": "91903854-5dc4-4661-c6d6-cae5ab56bdf2"
174
+ },
175
+ "outputs": [
176
+ {
177
+ "name": "stdout",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "Collecting flask-ngrok\n",
181
+ " Downloading flask_ngrok-0.0.25-py3-none-any.whl (3.1 kB)\n",
182
+ "Requirement already satisfied: Flask>=0.8 in /usr/local/lib/python3.10/dist-packages (from flask-ngrok) (2.2.5)\n",
183
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from flask-ngrok) (2.31.0)\n",
184
+ "Requirement already satisfied: Werkzeug>=2.2.2 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (3.0.1)\n",
185
+ "Requirement already satisfied: Jinja2>=3.0 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (3.1.2)\n",
186
+ "Requirement already satisfied: itsdangerous>=2.0 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (2.1.2)\n",
187
+ "Requirement already satisfied: click>=8.0 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (8.1.7)\n",
188
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (3.3.2)\n",
189
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (3.6)\n",
190
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (2.0.7)\n",
191
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (2023.11.17)\n",
192
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.0->Flask>=0.8->flask-ngrok) (2.1.3)\n",
193
+ "Installing collected packages: flask-ngrok\n",
194
+ "Successfully installed flask-ngrok-0.0.25\n",
195
+ "Collecting pyngrok==4.1.1\n",
196
+ " Downloading pyngrok-4.1.1.tar.gz (18 kB)\n",
197
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
198
+ "Requirement already satisfied: future in /usr/local/lib/python3.10/dist-packages (from pyngrok==4.1.1) (0.18.3)\n",
199
+ "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from pyngrok==4.1.1) (6.0.1)\n",
200
+ "Building wheels for collected packages: pyngrok\n",
201
+ " Building wheel for pyngrok (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
202
+ " Created wheel for pyngrok: filename=pyngrok-4.1.1-py3-none-any.whl size=15963 sha256=b8ae0d70bccacfc72462262492c96e88942fd5e8113a0a9c6745caea83aad689\n",
203
+ " Stored in directory: /root/.cache/pip/wheels/4c/7c/4c/632fba2ea8e88d8890102eb07bc922e1ca8fa14db5902c91a8\n",
204
+ "Successfully built pyngrok\n",
205
+ "Installing collected packages: pyngrok\n",
206
+ "Successfully installed pyngrok-4.1.1\n",
207
+ "Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml\n"
208
+ ]
209
+ }
210
+ ],
211
+ "source": [
212
+ "!pip install flask-ngrok\n",
213
+ "!pip install pyngrok==4.1.1\n",
214
+ "!ngrok authtoken <your_token>"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": 10,
220
+ "metadata": {
221
+ "colab": {
222
+ "base_uri": "https://localhost:8080/"
223
+ },
224
+ "id": "Fo-q2Q-XMFen",
225
+ "outputId": "a762d258-7da5-44fc-fdad-7fef43ddd361"
226
+ },
227
+ "outputs": [
228
+ {
229
+ "name": "stdout",
230
+ "output_type": "stream",
231
+ "text": [
232
+ "/content\n",
233
+ "Cloning into 'clothes-virtual-try-on'...\n",
234
+ "remote: Enumerating objects: 154, done.\u001b[K\n",
235
+ "remote: Counting objects: 100% (22/22), done.\u001b[K\n",
236
+ "remote: Compressing objects: 100% (10/10), done.\u001b[K\n",
237
+ "remote: Total 154 (delta 16), reused 12 (delta 12), pack-reused 132\u001b[K\n",
238
+ "Receiving objects: 100% (154/154), 20.47 MiB | 33.87 MiB/s, done.\n",
239
+ "Resolving deltas: 100% (54/54), done.\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "import os\n",
245
+ "%cd /content/\n",
246
+ "!rm -rf clothes-virtual-try-on\n",
247
+ "!git clone https://github.com/practice404/clothes-virtual-try-on.git\n",
248
+ "os.makedirs(\"/content/clothes-virtual-try-on/checkpoints\")"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": 11,
254
+ "metadata": {
255
+ "colab": {
256
+ "base_uri": "https://localhost:8080/"
257
+ },
258
+ "id": "tnud6ptA9ZwL",
259
+ "outputId": "bc5ee612-eb57-4118-f5e0-6221484c9571"
260
+ },
261
+ "outputs": [
262
+ {
263
+ "name": "stdout",
264
+ "output_type": "stream",
265
+ "text": [
266
+ "/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
267
+ " warnings.warn(\n",
268
+ "Downloading...\n",
269
+ "From (uriginal): https://drive.google.com/uc?id=18q4lS7cNt1_X8ewCgya1fq0dSk93jTL6\n",
270
+ "From (redirected): https://drive.google.com/uc?id=18q4lS7cNt1_X8ewCgya1fq0dSk93jTL6&confirm=t&uuid=6db7053d-df6d-41e3-ac5b-2f924303335f\n",
271
+ "To: /content/clothes-virtual-try-on/checkpoints/alias_final.pth\n",
272
+ "100% 402M/402M [00:01<00:00, 254MB/s]\n",
273
+ "/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
274
+ " warnings.warn(\n",
275
+ "Downloading...\n",
276
+ "From: https://drive.google.com/uc?id=1uDRPY8gh9sHb3UDonq6ZrINqDOd7pmTz\n",
277
+ "To: /content/clothes-virtual-try-on/checkpoints/gmm_final.pth\n",
278
+ "100% 76.2M/76.2M [00:00<00:00, 223MB/s]\n",
279
+ "/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
280
+ " warnings.warn(\n",
281
+ "Downloading...\n",
282
+ "From (uriginal): https://drive.google.com/uc?id=1d7lZNLh51Qt5Mi1lXqyi6Asb2ncLrEdC\n",
283
+ "From (redirected): https://drive.google.com/uc?id=1d7lZNLh51Qt5Mi1lXqyi6Asb2ncLrEdC&confirm=t&uuid=78aeda19-f21d-4598-8bdf-d08a78a99149\n",
284
+ "To: /content/clothes-virtual-try-on/checkpoints/seg_final.pth\n",
285
+ "100% 138M/138M [00:01<00:00, 135MB/s]\n"
286
+ ]
287
+ }
288
+ ],
289
+ "source": [
290
+ "!gdown --id 18q4lS7cNt1_X8ewCgya1fq0dSk93jTL6 --output /content/clothes-virtual-try-on/checkpoints/alias_final.pth\n",
291
+ "!gdown --id 1uDRPY8gh9sHb3UDonq6ZrINqDOd7pmTz --output /content/clothes-virtual-try-on/checkpoints/gmm_final.pth\n",
292
+ "!gdown --id 1d7lZNLh51Qt5Mi1lXqyi6Asb2ncLrEdC --output /content/clothes-virtual-try-on/checkpoints/seg_final.pth"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 12,
298
+ "metadata": {
299
+ "colab": {
300
+ "base_uri": "https://localhost:8080/"
301
+ },
302
+ "id": "qWPkjShFMK82",
303
+ "outputId": "cf51a4d3-4833-4788-9878-92a791a944b8"
304
+ },
305
+ "outputs": [
306
+ {
307
+ "name": "stdout",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
311
+ " warnings.warn(\n",
312
+ "Downloading...\n",
313
+ "From (uriginal): https://drive.google.com/uc?id=1ysEoAJNxou7RNuT9iKOxRhjVRNY5RLjx\n",
314
+ "From (redirected): https://drive.google.com/uc?id=1ysEoAJNxou7RNuT9iKOxRhjVRNY5RLjx&confirm=t&uuid=50dc2d49-15b3-47ed-905f-fc2455dfea07\n",
315
+ "To: /content/clothes-virtual-try-on/cloth_segm_u2net_latest.pth\n",
316
+ "100% 177M/177M [00:00<00:00, 178MB/s]\n"
317
+ ]
318
+ }
319
+ ],
320
+ "source": [
321
+ "!gdown --id 1ysEoAJNxou7RNuT9iKOxRhjVRNY5RLjx --output /content/clothes-virtual-try-on/cloth_segm_u2net_latest.pth --no-cookies"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": 13,
327
+ "metadata": {
328
+ "colab": {
329
+ "base_uri": "https://localhost:8080/"
330
+ },
331
+ "id": "I9MhYntvMP84",
332
+ "outputId": "774ecb82-0f56-4b65-a6fd-0baa9416d75c"
333
+ },
334
+ "outputs": [
335
+ {
336
+ "name": "stdout",
337
+ "output_type": "stream",
338
+ "text": [
339
+ "/content\n",
340
+ "Collecting ninja\n",
341
+ " Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)\n",
342
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.2/307.2 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
343
+ "\u001b[?25hInstalling collected packages: ninja\n",
344
+ "Successfully installed ninja-1.11.1.1\n"
345
+ ]
346
+ }
347
+ ],
348
+ "source": [
349
+ "%cd /content/\n",
350
+ "!pip install ninja"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 14,
356
+ "metadata": {
357
+ "colab": {
358
+ "base_uri": "https://localhost:8080/"
359
+ },
360
+ "id": "Rz9LOnvyMWEJ",
361
+ "outputId": "55af9ee0-cdf5-495f-a14f-ac93695a5fbe"
362
+ },
363
+ "outputs": [
364
+ {
365
+ "name": "stdout",
366
+ "output_type": "stream",
367
+ "text": [
368
+ "Cloning into 'Self-Correction-Human-Parsing'...\n",
369
+ "remote: Enumerating objects: 719, done.\u001b[K\n",
370
+ "remote: Counting objects: 100% (719/719), done.\u001b[K\n",
371
+ "remote: Compressing objects: 100% (568/568), done.\u001b[K\n",
372
+ "remote: Total 719 (delta 149), reused 611 (delta 140), pack-reused 0\u001b[K\n",
373
+ "Receiving objects: 100% (719/719), 3.88 MiB | 12.81 MiB/s, done.\n",
374
+ "Resolving deltas: 100% (149/149), done.\n",
375
+ "/content/Self-Correction-Human-Parsing\n"
376
+ ]
377
+ }
378
+ ],
379
+ "source": [
380
+ "!git clone https://github.com/PeikeLi/Self-Correction-Human-Parsing\n",
381
+ "%cd Self-Correction-Human-Parsing\n",
382
+ "!mkdir checkpoints"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 15,
388
+ "metadata": {
389
+ "colab": {
390
+ "base_uri": "https://localhost:8080/"
391
+ },
392
+ "id": "b2k0DLCsMaG0",
393
+ "outputId": "a28b0d51-14a3-426b-a2cb-d9b209e2b202"
394
+ },
395
+ "outputs": [
396
+ {
397
+ "name": "stdout",
398
+ "output_type": "stream",
399
+ "text": [
400
+ "/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
401
+ " warnings.warn(\n",
402
+ "Downloading...\n",
403
+ "From (uriginal): https://drive.google.com/uc?id=1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH\n",
404
+ "From (redirected): https://drive.google.com/uc?id=1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH&confirm=t&uuid=83091795-9ef5-449e-8d11-008bbe2238eb\n",
405
+ "To: /content/Self-Correction-Human-Parsing/exp-schp-201908261155-lip.pth\n",
406
+ "100% 267M/267M [00:01<00:00, 210MB/s]\n"
407
+ ]
408
+ }
409
+ ],
410
+ "source": [
411
+ "# downloading LIP dataset model\n",
412
+ "!gdown --id 1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH\n",
413
+ "!mv /content/Self-Correction-Human-Parsing/exp-schp-201908261155-lip.pth /content/Self-Correction-Human-Parsing/checkpoints/final.pth"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": 16,
419
+ "metadata": {
420
+ "colab": {
421
+ "base_uri": "https://localhost:8080/"
422
+ },
423
+ "id": "2Y4f3VRyMd9Z",
424
+ "outputId": "c81b2103-5c6b-4af5-aad1-c51dc65b1015"
425
+ },
426
+ "outputs": [
427
+ {
428
+ "name": "stdout",
429
+ "output_type": "stream",
430
+ "text": [
431
+ "/content\n"
432
+ ]
433
+ }
434
+ ],
435
+ "source": [
436
+ "%cd /content/"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {
443
+ "id": "1k2dVj4vMhwA"
444
+ },
445
+ "outputs": [],
446
+ "source": [
447
+ "%%bash\n",
448
+ "MINICONDA_INSTALLER_SCRIPT=Miniconda3-4.5.4-Linux-x86_64.sh\n",
449
+ "MINICONDA_PREFIX=/usr/local\n",
450
+ "wget https://repo.continuum.io/miniconda/$MINICONDA_INSTALLER_SCRIPT\n",
451
+ "chmod +x $MINICONDA_INSTALLER_SCRIPT\n",
452
+ "./$MINICONDA_INSTALLER_SCRIPT -b -f -p $MINICONDA_PREFIX\n",
453
+ "conda install --channel defaults conda python=3.8 --yes\n",
454
+ "conda update --channel defaults --all --yes"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": 18,
460
+ "metadata": {
461
+ "id": "I6entwp3MliV"
462
+ },
463
+ "outputs": [],
464
+ "source": [
465
+ "import sys\n",
466
+ "_ = (sys.path\n",
467
+ " .append(\"/usr/local/lib/python3.6/site-packages\"))"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "metadata": {
474
+ "id": "cseosViyMtYx"
475
+ },
476
+ "outputs": [],
477
+ "source": [
478
+ "!conda install --channel conda-forge featuretools --yes"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {
485
+ "id": "BEnK6NI6M0cz"
486
+ },
487
+ "outputs": [],
488
+ "source": [
489
+ "!pip install opencv-python torchgeometry"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "metadata": {
496
+ "id": "HkySNWttHdW2"
497
+ },
498
+ "outputs": [],
499
+ "source": [
500
+ "!pip install torchvision"
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "markdown",
505
+ "metadata": {
506
+ "id": "-wvtaXujRhNp"
507
+ },
508
+ "source": [
509
+ "# Welcome to Virtual-Cloth-Assistant\n",
510
+ "\n",
511
+ "> The first run takes some extra time to set up the environment and download the model weights"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": 22,
517
+ "metadata": {
518
+ "id": "RwcUm39LM8H0"
519
+ },
520
+ "outputs": [],
521
+ "source": [
522
+ "def make_dir():\n",
523
+ " os.system(\"cd /content/ && mkdir inputs && cd inputs && mkdir test && cd test && mkdir cloth cloth-mask image image-parse openpose-img openpose-json\")"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "code",
528
+ "execution_count": 28,
529
+ "metadata": {
530
+ "id": "9jGRSFuEM9-q"
531
+ },
532
+ "outputs": [],
533
+ "source": [
534
+ "from flask import Flask, request, send_file, jsonify\n",
535
+ "from flask_ngrok import run_with_ngrok\n",
536
+ "from PIL import Image\n",
537
+ "import base64\n",
538
+ "import io\n",
539
+ "\n",
540
+ "app = Flask(__name__)\n",
541
+ "run_with_ngrok(app)\n",
542
+ "\n",
543
+ "@app.route(\"/\")\n",
544
+ "def home():\n",
545
+ " return jsonify(\"hello world\")\n",
546
+ "\n",
547
+ "@app.route(\"/api/transform\", methods=['POST'])\n",
548
+ "def begin():\n",
549
+ " make_dir()\n",
550
+ " print(\"data recieved\")\n",
551
+ " cloth = request.files['cloth']\n",
552
+ " model = request.files['model']\n",
553
+ "\n",
554
+ " cloth = Image.open(cloth.stream)\n",
555
+ " model = Image.open(model.stream)\n",
556
+ "\n",
557
+ " cloth.save(\"/content/inputs/test/cloth/cloth.jpg\")\n",
558
+ " model.save(\"/content/inputs/test/image/model.jpg\")\n",
559
+ "\n",
560
+ " # running script to compute the predictions\n",
561
+ " os.system(\"python /content/clothes-virtual-try-on/run.py\")\n",
562
+ "\n",
563
+ " # loading output\n",
564
+ " op = os.listdir(\"/content/output\")[0]\n",
565
+ " op = Image.open(f\"/content/output/{op}\")\n",
566
+ " buffer = io.BytesIO()\n",
567
+ " op.save(buffer, 'png')\n",
568
+ " buffer.seek(0)\n",
569
+ " os.system(\"rm -rf /content/output/\")\n",
570
+ " return send_file(buffer, mimetype='image/png')"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": 29,
576
+ "metadata": {
577
+ "colab": {
578
+ "base_uri": "https://localhost:8080/"
579
+ },
580
+ "id": "Pl52gbcqZ3GL",
581
+ "outputId": "72cbd1c5-afa8-4411-9b3e-75b192b65a07"
582
+ },
583
+ "outputs": [
584
+ {
585
+ "name": "stdout",
586
+ "output_type": "stream",
587
+ "text": [
588
+ " * Serving Flask app '__main__'\n",
589
+ " * Debug mode: off\n"
590
+ ]
591
+ },
592
+ {
593
+ "name": "stderr",
594
+ "output_type": "stream",
595
+ "text": [
596
+ "INFO:werkzeug:\u001b[31m\u001b[1mWARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.\u001b[0m\n",
597
+ " * Running on http://127.0.0.1:5000\n",
598
+ "INFO:werkzeug:\u001b[33mPress CTRL+C to quit\u001b[0m\n"
599
+ ]
600
+ },
601
+ {
602
+ "name": "stdout",
603
+ "output_type": "stream",
604
+ "text": [
605
+ " * Running on http://e793-34-123-73-186.ngrok-free.app\n",
606
+ " * Traffic stats available on http://127.0.0.1:4040\n",
607
+ "data recieved\n"
608
+ ]
609
+ },
610
+ {
611
+ "name": "stderr",
612
+ "output_type": "stream",
613
+ "text": [
614
+ "INFO:werkzeug:127.0.0.1 - - [19/Dec/2023 14:36:09] \"POST /api/transform HTTP/1.1\" 200 -\n"
615
+ ]
616
+ }
617
+ ],
618
+ "source": [
619
+ "if __name__ == '__main__':\n",
620
+ " app.run()"
621
+ ]
622
+ }
623
+ ],
624
+ "metadata": {
625
+ "accelerator": "GPU",
626
+ "colab": {
627
+ "collapsed_sections": [
628
+ "8gqt11Y_RYxU"
629
+ ],
630
+ "provenance": []
631
+ },
632
+ "gpuClass": "standard",
633
+ "kernelspec": {
634
+ "display_name": "Python 3",
635
+ "name": "python3"
636
+ },
637
+ "language_info": {
638
+ "name": "python"
639
+ }
640
+ },
641
+ "nbformat": 4,
642
+ "nbformat_minor": 0
643
+ }
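
Once the Flask app above is running behind ngrok, the /api/transform route can be exercised with a small client. The forwarding URL below is a placeholder for the address ngrok prints at startup, and the image file names are illustrative.

    import requests

    url = "https://<your-ngrok-subdomain>.ngrok-free.app/api/transform"   # placeholder forwarding URL
    with open("cloth.jpg", "rb") as cloth, open("model.jpg", "rb") as person:
        resp = requests.post(url, files={"cloth": cloth, "model": person})
    resp.raise_for_status()
    with open("try_on_result.png", "wb") as out:
        out.write(resp.content)   # the endpoint streams back the generated try-on image
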
test.py ADDED
@@ -0,0 +1,155 @@
1
+ import argparse
2
+ import os
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ import torchgeometry as tgm
8
+
9
+ from datasets import VITONDataset, VITONDataLoader
10
+ from network import SegGenerator, GMM, ALIASGenerator
11
+ from utils import gen_noise, load_checkpoint, save_images
12
+
13
+
14
+ def get_opt():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument('--name', type=str, required=True)
17
+
18
+ parser.add_argument('-b', '--batch_size', type=int, default=1)
19
+ parser.add_argument('-j', '--workers', type=int, default=1)
20
+ parser.add_argument('--load_height', type=int, default=1024)
21
+ parser.add_argument('--load_width', type=int, default=768)
22
+ parser.add_argument('--shuffle', action='store_true')
23
+
24
+ parser.add_argument('--dataset_dir', type=str, default='./datasets/')
25
+ parser.add_argument('--dataset_mode', type=str, default='test')
26
+ parser.add_argument('--dataset_list', type=str, default='test_pairs.txt')
27
+ parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/')
28
+ parser.add_argument('--save_dir', type=str, default='./results/')
29
+
30
+ parser.add_argument('--display_freq', type=int, default=1)
31
+
32
+ parser.add_argument('--seg_checkpoint', type=str, default='seg_final.pth')
33
+ parser.add_argument('--gmm_checkpoint', type=str, default='gmm_final.pth')
34
+ parser.add_argument('--alias_checkpoint', type=str, default='alias_final.pth')
35
+
36
+ # common
37
+ parser.add_argument('--semantic_nc', type=int, default=13, help='# of human-parsing map classes')
38
+ parser.add_argument('--init_type', choices=['normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none'], default='xavier')
39
+ parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
40
+
41
+ # for GMM
42
+ parser.add_argument('--grid_size', type=int, default=5)
43
+
44
+ # for ALIASGenerator
45
+ parser.add_argument('--norm_G', type=str, default='spectralaliasinstance')
46
+ parser.add_argument('--ngf', type=int, default=64, help='# of generator filters in the first conv layer')
47
+ parser.add_argument('--num_upsampling_layers', choices=['normal', 'more', 'most'], default='most',
48
+ help='If \'more\', add upsampling layer between the two middle resnet blocks. '
49
+ 'If \'most\', also add one more (upsampling + resnet) layer at the end of the generator.')
50
+
51
+ opt = parser.parse_args()
52
+ return opt
53
+
54
+
55
+ def test(opt, seg, gmm, alias):
56
+ up = nn.Upsample(size=(opt.load_height, opt.load_width), mode='bilinear')
57
+ gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
58
+ gauss.cuda()
59
+
60
+ test_dataset = VITONDataset(opt)
61
+ test_loader = VITONDataLoader(opt, test_dataset)
62
+
63
+ with torch.no_grad():
64
+ for i, inputs in enumerate(test_loader.data_loader):
65
+ img_names = inputs['img_name']
66
+ c_names = inputs['c_name']['unpaired']
67
+
68
+ img_agnostic = inputs['img_agnostic'].cuda()
69
+ parse_agnostic = inputs['parse_agnostic'].cuda()
70
+ pose = inputs['pose'].cuda()
71
+ c = inputs['cloth']['unpaired'].cuda()
72
+ cm = inputs['cloth_mask']['unpaired'].cuda()
73
+
74
+ # Part 1. Segmentation generation
75
+ parse_agnostic_down = F.interpolate(parse_agnostic, size=(256, 192), mode='bilinear')
76
+ pose_down = F.interpolate(pose, size=(256, 192), mode='bilinear')
77
+ c_masked_down = F.interpolate(c * cm, size=(256, 192), mode='bilinear')
78
+ cm_down = F.interpolate(cm, size=(256, 192), mode='bilinear')
79
+ seg_input = torch.cat((cm_down, c_masked_down, parse_agnostic_down, pose_down, gen_noise(cm_down.size()).cuda()), dim=1)
80
+
81
+ parse_pred_down = seg(seg_input)
82
+ parse_pred = gauss(up(parse_pred_down))
83
+ parse_pred = parse_pred.argmax(dim=1)[:, None]
84
+
85
+ parse_old = torch.zeros(parse_pred.size(0), 13, opt.load_height, opt.load_width, dtype=torch.float).cuda()
86
+ parse_old.scatter_(1, parse_pred, 1.0)
87
+
88
+ labels = {
89
+ 0: ['background', [0]],
90
+ 1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
91
+ 2: ['upper', [3]],
92
+ 3: ['hair', [1]],
93
+ 4: ['left_arm', [5]],
94
+ 5: ['right_arm', [6]],
95
+ 6: ['noise', [12]]
96
+ }
97
+ parse = torch.zeros(parse_pred.size(0), 7, opt.load_height, opt.load_width, dtype=torch.float).cuda()
98
+ for j in range(len(labels)):
99
+ for label in labels[j][1]:
100
+ parse[:, j] += parse_old[:, label]
101
+
102
+ # Part 2. Clothes Deformation
103
+ agnostic_gmm = F.interpolate(img_agnostic, size=(256, 192), mode='nearest')
104
+ parse_cloth_gmm = F.interpolate(parse[:, 2:3], size=(256, 192), mode='nearest')
105
+ pose_gmm = F.interpolate(pose, size=(256, 192), mode='nearest')
106
+ c_gmm = F.interpolate(c, size=(256, 192), mode='nearest')
107
+ gmm_input = torch.cat((parse_cloth_gmm, pose_gmm, agnostic_gmm), dim=1)
108
+
109
+ _, warped_grid = gmm(gmm_input, c_gmm)
110
+ warped_c = F.grid_sample(c, warped_grid, padding_mode='border')
111
+ warped_cm = F.grid_sample(cm, warped_grid, padding_mode='border')
112
+
113
+ # Part 3. Try-on synthesis
114
+ misalign_mask = parse[:, 2:3] - warped_cm
115
+ misalign_mask[misalign_mask < 0.0] = 0.0
116
+ parse_div = torch.cat((parse, misalign_mask), dim=1)
117
+ parse_div[:, 2:3] -= misalign_mask
118
+
119
+ output = alias(torch.cat((img_agnostic, pose, warped_c), dim=1), parse, parse_div, misalign_mask)
120
+
121
+ unpaired_names = []
122
+ for img_name, c_name in zip(img_names, c_names):
123
+ unpaired_names.append('{}_{}'.format(img_name.split('_')[0], c_name))
124
+
125
+ save_images(output, unpaired_names, os.path.join(opt.save_dir, opt.name))
126
+
127
+ if (i + 1) % opt.display_freq == 0:
128
+ print("step: {}".format(i + 1))
129
+
130
+
131
+ def main():
132
+ opt = get_opt()
133
+ print(opt)
134
+
135
+ if not os.path.exists(os.path.join(opt.save_dir, opt.name)):
136
+ os.makedirs(os.path.join(opt.save_dir, opt.name))
137
+
138
+ seg = SegGenerator(opt, input_nc=opt.semantic_nc + 8, output_nc=opt.semantic_nc)
139
+ gmm = GMM(opt, inputA_nc=7, inputB_nc=3)
140
+ opt.semantic_nc = 7
141
+ alias = ALIASGenerator(opt, input_nc=9)
142
+ opt.semantic_nc = 13
143
+
144
+ load_checkpoint(seg, os.path.join(opt.checkpoint_dir, opt.seg_checkpoint))
145
+ load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, opt.gmm_checkpoint))
146
+ load_checkpoint(alias, os.path.join(opt.checkpoint_dir, opt.alias_checkpoint))
147
+
148
+ seg.cuda().eval()
149
+ gmm.cuda().eval()
150
+ alias.cuda().eval()
151
+ test(opt, seg, gmm, alias)
152
+
153
+
154
+ if __name__ == '__main__':
155
+ main()
utils.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+
8
+
9
+ def gen_noise(shape):
10
+ noise = np.zeros(shape, dtype=np.uint8)
11
+ ### noise
12
+ noise = cv2.randn(noise, 0, 255)
13
+ noise = np.asarray(noise / 255, dtype=np.uint8)
14
+ noise = torch.tensor(noise, dtype=torch.float32)
15
+ return noise
16
+
17
+
18
+ def save_images(img_tensors, img_names, save_dir):
19
+ for img_tensor, img_name in zip(img_tensors, img_names):
20
+ tensor = (img_tensor.clone()+1)*0.5 * 255
21
+ tensor = tensor.cpu().clamp(0,255)
22
+
23
+ try:
24
+ array = tensor.numpy().astype('uint8')
25
+ except:
26
+ array = tensor.detach().numpy().astype('uint8')
27
+
28
+ if array.shape[0] == 1:
29
+ array = array.squeeze(0)
30
+ elif array.shape[0] == 3:
31
+ array = array.swapaxes(0, 1).swapaxes(1, 2)
32
+
33
+ im = Image.fromarray(array)
34
+ im.save(os.path.join(save_dir, img_name), format='JPEG')
35
+
36
+
37
+ def load_checkpoint(model, checkpoint_path):
38
+ if not os.path.exists(checkpoint_path):
39
+ raise ValueError("'{}' is not a valid checkpoint path".format(checkpoint_path))
40
+ model.load_state_dict(torch.load(checkpoint_path))
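
save_images expects generator outputs in [-1, 1] laid out as NCHW, with one file name per batch element. A minimal round-trip sketch with a dummy tensor; the output directory is a placeholder and is created on the fly.

    import os
    import torch
    from utils import save_images

    out_dir = "./results/demo"                    # placeholder output directory
    os.makedirs(out_dir, exist_ok=True)
    fake = torch.rand(2, 3, 1024, 768) * 2 - 1    # dummy batch standing in for generator output in [-1, 1]
    save_images(fake, ["demo_0.jpg", "demo_1.jpg"], out_dir)
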