victorisgeek committed
Commit 10ca7b7
Parent(s): 059b53a

Upload folder using huggingface_hub
Files changed:
- .gitignore +162 -0
- README.md +50 -12
- assets/cloth/01260_00.jpg +0 -0
- assets/cloth/01430_00.jpg +0 -0
- assets/cloth/02783_00.jpg +0 -0
- assets/cloth/03751_00.jpg +0 -0
- assets/cloth/06429_00.jpg +0 -0
- assets/cloth/06802_00.jpg +0 -0
- assets/cloth/07429_00.jpg +0 -0
- assets/cloth/08348_00.jpg +0 -0
- assets/cloth/09933_00.jpg +0 -0
- assets/cloth/11028_00.jpg +0 -0
- assets/cloth/11351_00.jpg +0 -0
- assets/cloth/11791_00.jpg +0 -0
- assets/image/00891_00.jpg +0 -0
- assets/image/03615_00.jpg +0 -0
- assets/image/07445_00.jpg +0 -0
- assets/image/07573_00.jpg +0 -0
- assets/image/08909_00.jpg +0 -0
- assets/image/10549_00.jpg +0 -0
- client-side/app.py +38 -0
- client-side/static/css/style.css +0 -0
- client-side/static/images/logo.png +0 -0
- client-side/static/output/dog.png +0 -0
- client-side/templates/index.html +249 -0
- cloth-mask.py +124 -0
- datasets.py +224 -0
- network.py +526 -0
- networks/__init__.py +1 -0
- networks/u2net.py +565 -0
- remove_bg.py +59 -0
- run.py +41 -0
- setup_gradio.ipynb +0 -0
- setup_ngrok.ipynb +643 -0
- test.py +155 -0
- utils.py +40 -0
.gitignore
ADDED
@@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode
.DS_Store
README.md
CHANGED
@@ -1,12 +1,50 @@
# Clothes Virtual Try On
[![Open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SwayamInSync/clothes-virtual-try-on/blob/main/setup_gradio.ipynb)

## Updates
- **[19/02/2024] From now on this repo won't receive any future updates from my side (Spoiler: It's not gone for good 😉. Expect its return, stronger than ever.) (Community contributions and issue discussions are still welcome 🤗)**
- [26/12/2023] Added the Gradio interface and removed all the external dependencies
- [19/12/2023] Fixed the `openpose` installation and missing model weights issue
- [19/12/2023] Replaced the `remove.bg` dependency with `rembg`
- [26/04/2023] Fixed the GAN generation issue

## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=SwayamInSync/clothes-virtual-try-on&type=Date)](https://star-history.com/#SwayamInSync/clothes-virtual-try-on&Date)

## Table of contents
- [Clothes Virtual Try On](#clothes-virtual-try-on)
  - [Table of contents](#table-of-contents)
  - [General info](#general-info)
  - [Demo](#demo)
  - [Block Diagram](#block-diagram)
  - [Methodology](#methodology)
  - [Usage](#usage)
  - [Citation](#citation)

## General info

This project is part of a Crework community project. While buying clothes online, it is difficult for a customer to select a desirable outfit on the first attempt, because they can't try the clothes on. This project aims to solve that problem.

<img width="383" alt="general_info" src="https://user-images.githubusercontent.com/63489382/163923011-c2898812-2491-4ec2-beb7-dcaaaf680e4f.png">

## Demo

https://user-images.githubusercontent.com/63489382/163922795-5dbb0f52-95e4-42c6-95d7-2d965abeba6d.mp4

## Block Diagram
![block_diagram_whole](https://user-images.githubusercontent.com/63489382/163922947-c1677f79-ad6f-4550-affc-7d4e80f0d247.png)

## Methodology
![block_diagram_detailed](https://user-images.githubusercontent.com/63489382/163922991-86d148c2-1a97-48a5-b4ec-d8c16819374a.png)

## Usage
- Just click on the `Open in Colab` button at the top of this README file

## Citation
**Work in progress**
assets/cloth/01260_00.jpg
ADDED
assets/cloth/01430_00.jpg
ADDED
assets/cloth/02783_00.jpg
ADDED
assets/cloth/03751_00.jpg
ADDED
assets/cloth/06429_00.jpg
ADDED
assets/cloth/06802_00.jpg
ADDED
assets/cloth/07429_00.jpg
ADDED
assets/cloth/08348_00.jpg
ADDED
assets/cloth/09933_00.jpg
ADDED
assets/cloth/11028_00.jpg
ADDED
assets/cloth/11351_00.jpg
ADDED
assets/cloth/11791_00.jpg
ADDED
assets/image/00891_00.jpg
ADDED
assets/image/03615_00.jpg
ADDED
assets/image/07445_00.jpg
ADDED
assets/image/07573_00.jpg
ADDED
assets/image/08909_00.jpg
ADDED
assets/image/10549_00.jpg
ADDED
client-side/app.py
ADDED
@@ -0,0 +1,38 @@
from flask import Flask, request, jsonify, render_template
from PIL import Image
import requests
from io import BytesIO
import base64

app = Flask(__name__)


@app.route('/')
def home():
    return render_template("index.html")


@app.route("/preds", methods=['POST'])
def submit():
    cloth = request.files['cloth']
    model = request.files['model']

    # Replace this URL with the ngrok URL printed by the notebook on the server.
    url = "http://e793-34-123-73-186.ngrok-free.app/api/transform"
    print("sending")
    response = requests.post(url=url, files={"cloth": cloth.stream, "model": model.stream})
    op = Image.open(BytesIO(response.content))

    buffer = BytesIO()
    op.save(buffer, 'png')
    buffer.seek(0)

    data = buffer.read()
    data = base64.b64encode(data).decode()

    return render_template('index.html', op=data)
    # return render_template('index.html', test=True)

if __name__ == '__main__':
    app.run(debug=True)
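To smoke-test this client locally, a request against its `/preds` route might look like the sketch below. The two sample image paths come from this commit's assets folder, the port is Flask's default, and the upstream ngrok URL inside `submit()` must point at a live inference server first.

# Minimal local test of the /preds route (assumes `python app.py` is running
# on Flask's default port 5000 and can reach a live inference server).
import requests

with open("assets/cloth/01260_00.jpg", "rb") as cloth, \
        open("assets/image/00891_00.jpg", "rb") as model:
    resp = requests.post(
        "http://127.0.0.1:5000/preds",
        files={"cloth": cloth, "model": model},
    )

# On success, the route renders index.html with the result embedded as a base64 PNG.
print(resp.status_code, len(resp.content))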
client-side/static/css/style.css
ADDED
File without changes
client-side/static/images/logo.png
ADDED
client-side/static/output/dog.png
ADDED
client-side/templates/index.html
ADDED
@@ -0,0 +1,249 @@
<!doctype html>
<html>

<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">

  <script src="https://cdn.tailwindcss.com"></script>
</head>

<body>

  <!-- <h1>Welcome to your own virtual clothing assistant</h1>
  <form action="{{ url_for('submit') }}" method="post" enctype="multipart/form-data">
    <input type="file" name="cloth">
    <input type="file" name="model">
    <button type="submit">Submit</button>
  </form> -->

  <header class="text-gray-400 bg-gray-900 body-font">
    <div class="container mx-auto flex flex-wrap p-5 flex-col md:flex-row items-center">
      <a class="flex title-font font-medium items-center text-white mb-4 md:mb-0">
        <img src="{{url_for('static', filename='images/logo.png')}}" height="50" width="50" style="border-radius: 50%;"/>
        <span class="ml-3 text-xl">Virtual Cloth Assistant</span>
      </a>
    </div>
  </header>

  <section class="text-gray-400 bg-gray-900 body-font">
    <div class="container px-5 py-24 mx-auto">
      <div class="text-center mb-20">
        <h1 class="sm:text-3xl text-2xl font-medium title-font text-white mb-4">Virtual Cloth Assistant</h1>
        <p class="text-base leading-relaxed xl:w-2/4 lg:w-3/4 mx-auto text-gray-400 text-opacity-80">Wanna try out how that cloth suits you?
          <br>
          Upgrade your shopping experience with an intelligent trial room.
          <br> Check out our API and get your wish fulfilled in seconds!</p>
        <div class="flex mt-6 justify-center">
          <div class="w-16 h-1 rounded-full bg-blue-500 inline-flex"></div>
        </div>
      </div>
      <div class="flex flex-wrap sm:-m-4 -mx-4 -mb-10 -mt-4 md:space-y-0 space-y-6">
        <div class="p-4 md:w-1/3 flex flex-col text-center items-center">
          <div class="w-20 h-20 inline-flex items-center justify-center rounded-full bg-gray-800 text-blue-400 mb-5 flex-shrink-0">
            <svg fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" class="w-10 h-10" viewBox="0 0 24 24">
              <path d="M22 12h-4l-3 9L9 3l-3 9H2"></path>
            </svg>
          </div>
          <div class="flex-grow">
            <h2 class="text-white text-lg title-font font-medium mb-3">The Problem</h2>
            <p class="leading-relaxed text-base">While buying clothes online, it is difficult for a customer to select a desirable outfit on the first attempt, because they can't try the clothes on before they are delivered physically.</p>
          </div>
        </div>
        <div class="p-4 md:w-1/3 flex flex-col text-center items-center">
          <div class="w-20 h-20 inline-flex items-center justify-center rounded-full bg-gray-800 text-blue-400 mb-5 flex-shrink-0">
            <svg fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" class="w-10 h-10" viewBox="0 0 24 24">
              <circle cx="6" cy="6" r="3"></circle>
              <circle cx="6" cy="18" r="3"></circle>
              <path d="M20 4L8.12 15.88M14.47 14.48L20 20M8.12 8.12L12 12"></path>
            </svg>
          </div>
          <div class="flex-grow">
            <h2 class="text-white text-lg title-font font-medium mb-3">The Solution</h2>
            <p class="leading-relaxed text-base">E-commerce websites can be equipped with virtual trial rooms that allow users to try on multiple clothes virtually and select the best-looking outfit in a single attempt.</p>
          </div>
        </div>
        <div class="p-4 md:w-1/3 flex flex-col text-center items-center">
          <div class="w-20 h-20 inline-flex items-center justify-center rounded-full bg-gray-800 text-blue-400 mb-5 flex-shrink-0">
            <svg fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" class="w-10 h-10" viewBox="0 0 24 24">
              <path d="M20 21v-2a4 4 0 00-4-4H8a4 4 0 00-4 4v2"></path>
              <circle cx="12" cy="7" r="4"></circle>
            </svg>
          </div>
          <div class="flex-grow">
            <h2 class="text-white text-lg title-font font-medium mb-3">The Approach</h2>
            <p class="leading-relaxed text-base">
              We used deep learning to build VCA (Virtual Clothing Assistant) for consumers: the user selects the cloth they want to wear, uploads an image of themselves in any pose, and VCA dresses them in the selected cloth.
            </p>
          </div>
        </div>
      </div>
      <!-- <button class="flex mx-auto mt-16 text-white bg-blue-500 border-0 py-2 px-8 focus:outline-none hover:bg-blue-600 rounded text-lg">Button</button> -->
    </div>
  </section>

  <section class="text-gray-400 bg-gray-900 body-font">
    <form action="{{ url_for('submit') }}" method="post" enctype="multipart/form-data">
      <div class="container mx-auto flex flex-col px-5 py-24 justify-center items-center">

        <div class="flex flex-wrap -m-2">
          <div class="p-1 xl:w-1/2 md:w-1/2 w-full">
            <center><label class="block text-lr font-medium text-white-700"> Cloth Image </label></center>
            <div class="mt-1 flex justify-center px-6 pt-5 pb-6 border-2 border-gray-300 border-dashed rounded-md">
              <div class="space-y-1 text-center">
                <svg class="mx-auto h-12 w-12 text-gray-400" stroke="currentColor" fill="none" viewBox="0 0 48 48" aria-hidden="true">
                  <path d="M28 8H12a4 4 0 00-4 4v20m32-12v8m0 0v8a4 4 0 01-4 4H12a4 4 0 01-4-4v-4m32-4l-3.172-3.172a4 4 0 00-5.656 0L28 28M8 32l9.172-9.172a4 4 0 015.656 0L28 28m0 0l4 4m4-24h8m-4-4v8m-12 4h.02" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" />
                </svg>
                <div class="flex text-sm text-gray-600">
                  <!-- <label for="cloth-upload" class="relative cursor-pointer rounded-md font-medium text-indigo-600 hover:text-indigo-500 focus-within:outline-none focus-within:ring-2 focus-within:ring-offset-2 focus-within:ring-indigo-500">
                    <span>Upload a file</span>
                  </label> -->
                  <input class="block w-full text-sm text-slate-500
                    file:mr-4 file:py-2 file:px-4
                    file:rounded-full file:border-0
                    file:text-sm file:font-semibold
                    file:bg-violet-50 file:text-violet-700
                    hover:file:bg-violet-100" id="cloth-upload" type="file" name="cloth">
                  <p class="pl-1">or drag and drop</p>
                </div>
                <p class="text-xs text-gray-500">PNG, JPG up to 10MB</p>
              </div>
            </div>

          </div>
          <div class="p-1 xl:w-1/2 md:w-1/2 w-full">
            <center> <label class="block text-lr font-medium text-white-700"> Model Image </label></center>
            <div class="mt-1 flex justify-center px-6 pt-5 pb-6 border-2 border-gray-300 border-dashed rounded-md">
              <div class="space-y-1 text-center">
                <svg class="mx-auto h-12 w-12 text-gray-400" stroke="currentColor" fill="none" viewBox="0 0 48 48" aria-hidden="true">
                  <path d="M28 8H12a4 4 0 00-4 4v20m32-12v8m0 0v8a4 4 0 01-4 4H12a4 4 0 01-4-4v-4m32-4l-3.172-3.172a4 4 0 00-5.656 0L28 28M8 32l9.172-9.172a4 4 0 015.656 0L28 28m0 0l4 4m4-24h8m-4-4v8m-12 4h.02" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" />
                </svg>
                <div class="flex text-sm text-gray-600">
                  <!-- <label for="model-upload" class="relative cursor-pointer rounded-md font-medium text-indigo-600 hover:text-indigo-500 focus-within:outline-none focus-within:ring-2 focus-within:ring-offset-2 focus-within:ring-indigo-500">
                    <span>Upload a file</span>
                  </label> -->
                  <input class="block w-full text-sm text-slate-500
                    file:mr-4 file:py-2 file:px-4
                    file:rounded-full file:border-0
                    file:text-sm file:font-semibold
                    file:bg-violet-50 file:text-violet-700
                    hover:file:bg-violet-100" id="model-upload" type="file" name="model">
                  <p class="pl-1">or drag and drop</p>
                </div>
                <p class="text-xs text-gray-500">PNG, JPG up to 10MB</p>
              </div>
            </div>
          </div>
        </div>

        <br>
        <br>

        <div class="w-full md:w-2/3 flex flex-col mb-16 items-center text-center">
          <h1 class="title-font sm:text-4xl text-3xl mb-4 font-medium text-white">Upload corresponding images and get the Result</h1>
          <p class="mb-8 leading-relaxed">You'll need to wait 20-30 seconds, so chill and drink some water....</p>
          <div class="flex w-full justify-center items-end">
            <button type="submit" class="inline-flex text-white bg-blue-500 border-0 py-2 px-6 focus:outline-none hover:bg-blue-600 rounded text-lg">Try it</button>
          </div>
        </div>

      </div>

      <div>
        {% if op %}
        <center style="color: white; font-size: x-large;">HERE IS YOUR RESULT 🤗</center>

        <center>
          <div class="sm:w-3/4 mb-10 lg:mb-0 rounded-lg overflow-hidden">
            <img alt="output" class="object-cover object-center h-2/4 w-2/4" src="data:image/png;base64,{{ op }}">
          </div>
        </center>
        {% endif %}
      </div>
    </form>

  </section>

  <footer class="text-gray-400 bg-gray-900 body-font">
    <div class="container px-5 py-24 mx-auto flex md:items-center lg:items-start md:flex-row md:flex-nowrap flex-wrap flex-col">
      <div class="w-64 flex-shrink-0 md:mx-0 mx-auto text-center md:text-left md:mt-0 mt-10">
        <a class="flex title-font font-medium items-center md:justify-start justify-center text-white">
          <img src="{{url_for('static', filename='images/logo.png')}}" height="50" width="50" style="border-radius: 50%;"/>
          <span class="ml-3 text-xl">V-Cloth Assistant</span>
        </a>
      </div>
      <div class="flex-grow flex flex-wrap md:pr-20 -mb-10 md:text-left text-center order-first">
        <div class="lg:w-1/4 md:w-1/2 w-full px-4">
          <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">SWAYAM</h2>
        </div>
        <div class="lg:w-1/4 md:w-1/2 w-full px-4">
          <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">PARTH</h2>
        </div>
        <div class="lg:w-1/4 md:w-1/2 w-full px-4">
          <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">KEERTHI</h2>
        </div>
        <div class="lg:w-1/4 md:w-1/2 w-full px-4">
          <h2 class="title-font font-medium text-white tracking-widest text-sm mb-3">NAVANEETH</h2>
        </div>
      </div>
    </div>
    <div class="bg-gray-800 bg-opacity-75">
      <div class="container mx-auto py-4 px-5 flex flex-wrap flex-col sm:flex-row">
        <p class="text-gray-400 text-sm text-center sm:text-left">© 2022 Crework Batch 3 —
          <a href="https://twitter.com/knyttneve" class="text-gray-500 ml-1" rel="noopener noreferrer" target="_blank">@Crework</a>
        </p>
      </div>
    </div>
  </footer>

</body>

</html>
cloth-mask.py
ADDED
@@ -0,0 +1,124 @@
import os
from PIL import Image
import numpy as np
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from networks.u2net import U2NET

device = 'cuda'

image_dir = '/content/inputs/test/cloth'
result_dir = '/content/inputs/test/cloth-mask'
checkpoint_path = 'cloth_segm_u2net_latest.pth'


def load_checkpoint_mgpu(model, checkpoint_path):
    if not os.path.exists(checkpoint_path):
        print("----No checkpoints at given path----")
        return
    model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    new_state_dict = OrderedDict()
    for k, v in model_state_dict.items():
        name = k[7:]  # remove the `module.` prefix added by DataParallel
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)
    print("----checkpoints loaded from path: {}----".format(checkpoint_path))
    return model


class Normalize_image(object):
    """Normalize a given tensor with the given mean and standard deviation.

    Args:
        mean (float): desired mean to subtract from tensors
        std (float): desired std to divide tensors by
    """

    def __init__(self, mean, std):
        assert isinstance(mean, float)
        if isinstance(mean, float):
            self.mean = mean

        if isinstance(std, float):
            self.std = std

        self.normalize_1 = transforms.Normalize(self.mean, self.std)
        self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
        self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)

    def __call__(self, image_tensor):
        if image_tensor.shape[0] == 1:
            return self.normalize_1(image_tensor)
        elif image_tensor.shape[0] == 3:
            return self.normalize_3(image_tensor)
        elif image_tensor.shape[0] == 18:
            return self.normalize_18(image_tensor)
        else:
            raise ValueError("Please set proper channels! Normalization implemented only for 1, 3 and 18")


def get_palette(num_cls):
    """Return the color map for visualizing the segmentation mask.

    Args:
        num_cls: number of classes
    Returns:
        The color map
    """
    n = num_cls
    palette = [0] * (n * 3)
    for j in range(0, n):
        lab = j
        palette[j * 3 + 0] = 0
        palette[j * 3 + 1] = 0
        palette[j * 3 + 2] = 0
        i = 0
        while lab:
            palette[j * 3 + 0] = 255
            palette[j * 3 + 1] = 255
            palette[j * 3 + 2] = 255
            # palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
            # palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
            # palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
            i += 1
            lab >>= 3
    return palette


transforms_list = []
transforms_list += [transforms.ToTensor()]
transforms_list += [Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.to(device)
net = net.eval()

palette = get_palette(4)

images_list = sorted(os.listdir(image_dir))
for image_name in images_list:
    img = Image.open(os.path.join(image_dir, image_name)).convert('RGB')
    img_size = img.size
    img = img.resize((768, 768), Image.BICUBIC)
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)

    output_tensor = net(image_tensor.to(device))
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
    output_img = output_img.resize(img_size, Image.BICUBIC)

    output_img.putpalette(palette)
    output_img = output_img.convert('L')
    output_img.save(os.path.join(result_dir, image_name[:-4] + '.jpg'))
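One caveat worth noting: the loop above saves into result_dir without ever creating it, and the U2NET weights are not shipped in this commit. A small pre-flight sketch using the script's own hardcoded paths (where the weights come from is left as an assumption):

# Pre-flight for cloth-mask.py: create the expected folders and check the weights.
import os

image_dir = '/content/inputs/test/cloth'         # hardcoded defaults from cloth-mask.py
result_dir = '/content/inputs/test/cloth-mask'
checkpoint_path = 'cloth_segm_u2net_latest.pth'

os.makedirs(image_dir, exist_ok=True)
os.makedirs(result_dir, exist_ok=True)

# The checkpoint must be obtained separately; its source is not part of this
# commit, so download cloth_segm_u2net_latest.pth from wherever the project
# hosts its weights before running the script.
assert os.path.exists(checkpoint_path), "U2NET cloth-segmentation weights missing"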
datasets.py
ADDED
@@ -0,0 +1,224 @@
import json
from os import path as osp

import numpy as np
from PIL import Image, ImageDraw
import torch
from torch.utils import data
from torchvision import transforms


class VITONDataset(data.Dataset):
    def __init__(self, opt):
        super(VITONDataset, self).__init__()
        self.load_height = opt.load_height
        self.load_width = opt.load_width
        self.semantic_nc = opt.semantic_nc
        self.data_path = osp.join(opt.dataset_dir, opt.dataset_mode)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # load data list
        img_names = []
        c_names = []
        with open(osp.join(opt.dataset_dir, opt.dataset_list), 'r') as f:
            for line in f.readlines():
                img_name, c_name = line.strip().split()
                img_names.append(img_name)
                c_names.append(c_name)

        self.img_names = img_names
        self.c_names = dict()
        self.c_names['unpaired'] = c_names

    def get_parse_agnostic(self, parse, pose_data):
        parse_array = np.array(parse)
        parse_upper = ((parse_array == 5).astype(np.float32) +
                       (parse_array == 6).astype(np.float32) +
                       (parse_array == 7).astype(np.float32))
        parse_neck = (parse_array == 10).astype(np.float32)

        r = 10
        agnostic = parse.copy()

        # mask arms
        for parse_id, pose_ids in [(14, [2, 5, 6, 7]), (15, [5, 2, 3, 4])]:
            mask_arm = Image.new('L', (self.load_width, self.load_height), 'black')
            mask_arm_draw = ImageDraw.Draw(mask_arm)
            i_prev = pose_ids[0]
            for i in pose_ids[1:]:
                if (pose_data[i_prev, 0] == 0.0 and pose_data[i_prev, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
                    continue
                mask_arm_draw.line([tuple(pose_data[j]) for j in [i_prev, i]], 'white', width=r*10)
                pointx, pointy = pose_data[i]
                radius = r*4 if i == pose_ids[-1] else r*15
                mask_arm_draw.ellipse((pointx-radius, pointy-radius, pointx+radius, pointy+radius), 'white', 'white')
                i_prev = i
            parse_arm = (np.array(mask_arm) / 255) * (parse_array == parse_id).astype(np.float32)
            agnostic.paste(0, None, Image.fromarray(np.uint8(parse_arm * 255), 'L'))

        # mask torso & neck
        agnostic.paste(0, None, Image.fromarray(np.uint8(parse_upper * 255), 'L'))
        agnostic.paste(0, None, Image.fromarray(np.uint8(parse_neck * 255), 'L'))

        return agnostic

    def get_img_agnostic(self, img, parse, pose_data):
        parse_array = np.array(parse)
        parse_head = ((parse_array == 4).astype(np.float32) +
                      (parse_array == 13).astype(np.float32))
        parse_lower = ((parse_array == 9).astype(np.float32) +
                       (parse_array == 12).astype(np.float32) +
                       (parse_array == 16).astype(np.float32) +
                       (parse_array == 17).astype(np.float32) +
                       (parse_array == 18).astype(np.float32) +
                       (parse_array == 19).astype(np.float32))

        r = 20
        agnostic = img.copy()
        agnostic_draw = ImageDraw.Draw(agnostic)

        length_a = np.linalg.norm(pose_data[5] - pose_data[2])
        length_b = np.linalg.norm(pose_data[12] - pose_data[9])
        point = (pose_data[9] + pose_data[12]) / 2
        pose_data[9] = point + (pose_data[9] - point) / length_b * length_a
        pose_data[12] = point + (pose_data[12] - point) / length_b * length_a

        # mask arms
        agnostic_draw.line([tuple(pose_data[i]) for i in [2, 5]], 'gray', width=r*10)
        for i in [2, 5]:
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
        for i in [3, 4, 6, 7]:
            if (pose_data[i - 1, 0] == 0.0 and pose_data[i - 1, 1] == 0.0) or (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
                continue
            agnostic_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'gray', width=r*10)
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')

        # mask torso
        for i in [9, 12]:
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx-r*3, pointy-r*6, pointx+r*3, pointy+r*6), 'gray', 'gray')
        agnostic_draw.line([tuple(pose_data[i]) for i in [2, 9]], 'gray', width=r*6)
        agnostic_draw.line([tuple(pose_data[i]) for i in [5, 12]], 'gray', width=r*6)
        agnostic_draw.line([tuple(pose_data[i]) for i in [9, 12]], 'gray', width=r*12)
        agnostic_draw.polygon([tuple(pose_data[i]) for i in [2, 5, 12, 9]], 'gray', 'gray')

        # mask neck
        pointx, pointy = pose_data[1]
        agnostic_draw.rectangle((pointx-r*7, pointy-r*7, pointx+r*7, pointy+r*7), 'gray', 'gray')
        agnostic.paste(img, None, Image.fromarray(np.uint8(parse_head * 255), 'L'))
        agnostic.paste(img, None, Image.fromarray(np.uint8(parse_lower * 255), 'L'))

        return agnostic

    def __getitem__(self, index):
        img_name = self.img_names[index]
        c_name = {}
        c = {}
        cm = {}
        for key in self.c_names:
            c_name[key] = self.c_names[key][index]
            c[key] = Image.open(osp.join(self.data_path, 'cloth', c_name[key])).convert('RGB')
            c[key] = transforms.Resize(self.load_width, interpolation=2)(c[key])
            cm[key] = Image.open(osp.join(self.data_path, 'cloth-mask', c_name[key]))
            cm[key] = transforms.Resize(self.load_width, interpolation=0)(cm[key])

            c[key] = self.transform(c[key])  # [-1,1]
            cm_array = np.array(cm[key])
            cm_array = (cm_array >= 128).astype(np.float32)
            cm[key] = torch.from_numpy(cm_array)  # [0,1]
            cm[key].unsqueeze_(0)

        # load pose image
        pose_name = img_name.replace('.jpg', '_rendered.png')
        pose_rgb = Image.open(osp.join(self.data_path, 'openpose-img', pose_name))
        pose_rgb = transforms.Resize(self.load_width, interpolation=2)(pose_rgb)
        pose_rgb = self.transform(pose_rgb)  # [-1,1]

        pose_name = img_name.replace('.jpg', '_keypoints.json')
        with open(osp.join(self.data_path, 'openpose-json', pose_name), 'r') as f:
            pose_label = json.load(f)
            pose_data = pose_label['people'][0]['pose_keypoints_2d']
            pose_data = np.array(pose_data)
            pose_data = pose_data.reshape((-1, 3))[:, :2]

        # load parsing image
        parse_name = img_name.replace('.jpg', '.png')
        parse = Image.open(osp.join(self.data_path, 'image-parse', parse_name))
        parse = transforms.Resize(self.load_width, interpolation=0)(parse)
        parse_agnostic = self.get_parse_agnostic(parse, pose_data)
        parse_agnostic = torch.from_numpy(np.array(parse_agnostic)[None]).long()

        labels = {
            0: ['background', [0, 10]],
            1: ['hair', [1, 2]],
            2: ['face', [4, 13]],
            3: ['upper', [5, 6, 7]],
            4: ['bottom', [9, 12]],
            5: ['left_arm', [14]],
            6: ['right_arm', [15]],
            7: ['left_leg', [16]],
            8: ['right_leg', [17]],
            9: ['left_shoe', [18]],
            10: ['right_shoe', [19]],
            11: ['socks', [8]],
            12: ['noise', [3, 11]]
        }
        parse_agnostic_map = torch.zeros(20, self.load_height, self.load_width, dtype=torch.float)
        parse_agnostic_map.scatter_(0, parse_agnostic, 1.0)
        new_parse_agnostic_map = torch.zeros(self.semantic_nc, self.load_height, self.load_width, dtype=torch.float)
        for i in range(len(labels)):
            for label in labels[i][1]:
                new_parse_agnostic_map[i] += parse_agnostic_map[label]

        # load person image
        img = Image.open(osp.join(self.data_path, 'image', img_name))
        img = transforms.Resize(self.load_width, interpolation=2)(img)
        img_agnostic = self.get_img_agnostic(img, parse, pose_data)
        img = self.transform(img)
        img_agnostic = self.transform(img_agnostic)  # [-1,1]

        result = {
            'img_name': img_name,
            'c_name': c_name,
            'img': img,
            'img_agnostic': img_agnostic,
            'parse_agnostic': new_parse_agnostic_map,
            'pose': pose_rgb,
            'cloth': c,
            'cloth_mask': cm,
        }
        return result

    def __len__(self):
        return len(self.img_names)


class VITONDataLoader:
    def __init__(self, opt, dataset):
        super(VITONDataLoader, self).__init__()

        if opt.shuffle:
            train_sampler = data.sampler.RandomSampler(dataset)
        else:
            train_sampler = None

        self.data_loader = data.DataLoader(
            dataset, batch_size=opt.batch_size, shuffle=(train_sampler is None),
            num_workers=opt.workers, pin_memory=True, drop_last=True, sampler=train_sampler
        )
        self.dataset = dataset
        self.data_iter = self.data_loader.__iter__()

    def next_batch(self):
        try:
            batch = self.data_iter.__next__()
        except StopIteration:
            self.data_iter = self.data_loader.__iter__()
            batch = self.data_iter.__next__()

        return batch
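A sketch of how these two classes might be wired together. The opt namespace mirrors exactly the attributes the code above reads; the concrete values (notably the 1024x768 load size, VITON-HD's usual resolution) are assumptions, while semantic_nc=13 matches the 13 merged labels in __getitem__:

# Hypothetical wiring of VITONDataset + VITONDataLoader; values are illustrative.
from types import SimpleNamespace

opt = SimpleNamespace(
    load_height=1024, load_width=768,    # assumed VITON-HD resolution
    semantic_nc=13,                      # matches the merged label map above
    dataset_dir='./datasets', dataset_mode='test', dataset_list='test_pairs.txt',
    shuffle=False, batch_size=1, workers=1,
)

dataset = VITONDataset(opt)
loader = VITONDataLoader(opt, dataset)
batch = loader.next_batch()
print(batch['img'].shape, batch['cloth']['unpaired'].shape)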
network.py
ADDED
@@ -0,0 +1,526 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.utils.spectral_norm import spectral_norm


# ----------------------------------------------------------------------------------------------------------------------
# Common classes
# ----------------------------------------------------------------------------------------------------------------------
class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print("Network [{}] was created. Total number of parameters: {:.1f} million. "
              "To see the architecture, do print(network).".format(self.__class__.__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        def init_func(m):
            classname = m.__class__.__name__
            if 'BatchNorm2d' in classname:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif ('Conv' in classname or 'Linear' in classname) and hasattr(m, 'weight'):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError("initialization method '{}' is not implemented".format(init_type))
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

    def forward(self, *inputs):
        pass


# ----------------------------------------------------------------------------------------------------------------------
# SegGenerator-related classes
# ----------------------------------------------------------------------------------------------------------------------
class SegGenerator(BaseNetwork):
    def __init__(self, opt, input_nc, output_nc=13, norm_layer=nn.InstanceNorm2d):
        super(SegGenerator, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(input_nc, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(),
                                   nn.Conv2d(64, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU())

        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU())

        self.conv3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU(),
                                   nn.Conv2d(256, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU())

        self.conv4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU(),
                                   nn.Conv2d(512, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU())

        self.conv5 = nn.Sequential(nn.Conv2d(512, 1024, kernel_size=3, padding=1), norm_layer(1024), nn.ReLU(),
                                   nn.Conv2d(1024, 1024, kernel_size=3, padding=1), norm_layer(1024), nn.ReLU())

        self.up6 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(1024, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU())
        self.conv6 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU(),
                                   nn.Conv2d(512, 512, kernel_size=3, padding=1), norm_layer(512), nn.ReLU())

        self.up7 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(512, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU())
        self.conv7 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU(),
                                   nn.Conv2d(256, 256, kernel_size=3, padding=1), norm_layer(256), nn.ReLU())

        self.up8 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(256, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU())
        self.conv8 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU(),
                                   nn.Conv2d(128, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU())

        self.up9 = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                 nn.Conv2d(128, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU())
        self.conv9 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(),
                                   nn.Conv2d(64, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU(),
                                   nn.Conv2d(64, output_nc, kernel_size=3, padding=1))

        self.pool = nn.MaxPool2d(2)
        self.drop = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()

        self.print_network()
        self.init_weights(opt.init_type, opt.init_variance)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(self.pool(conv1))
        conv3 = self.conv3(self.pool(conv2))
        conv4 = self.drop(self.conv4(self.pool(conv3)))
        conv5 = self.drop(self.conv5(self.pool(conv4)))

        conv6 = self.conv6(torch.cat((conv4, self.up6(conv5)), 1))
        conv7 = self.conv7(torch.cat((conv3, self.up7(conv6)), 1))
        conv8 = self.conv8(torch.cat((conv2, self.up8(conv7)), 1))
        conv9 = self.conv9(torch.cat((conv1, self.up9(conv8)), 1))
        return self.sigmoid(conv9)


# ----------------------------------------------------------------------------------------------------------------------
# GMM-related classes
# ----------------------------------------------------------------------------------------------------------------------
class FeatureExtraction(BaseNetwork):
    def __init__(self, input_nc, ngf=64, num_layers=4, norm_layer=nn.BatchNorm2d):
        super(FeatureExtraction, self).__init__()

        nf = ngf
        layers = [nn.Conv2d(input_nc, nf, kernel_size=4, stride=2, padding=1), nn.ReLU(), norm_layer(nf)]

        for i in range(1, num_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            layers += [nn.Conv2d(nf_prev, nf, kernel_size=4, stride=2, padding=1), nn.ReLU(), norm_layer(nf)]

        layers += [nn.Conv2d(nf, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(), norm_layer(512)]
        layers += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU()]

        self.model = nn.Sequential(*layers)
        self.init_weights()

    def forward(self, x):
        return self.model(x)


class FeatureCorrelation(nn.Module):
    def __init__(self):
        super(FeatureCorrelation, self).__init__()

    def forward(self, featureA, featureB):
        # Reshape features for matrix multiplication.
        b, c, h, w = featureA.size()
        featureA = featureA.permute(0, 3, 2, 1).reshape(b, w * h, c)
        featureB = featureB.reshape(b, c, h * w)

        # Perform matrix multiplication.
        corr = torch.bmm(featureA, featureB).reshape(b, w * h, h, w)
        return corr


class FeatureRegression(nn.Module):
    def __init__(self, input_nc=512, output_size=6, norm_layer=nn.BatchNorm2d):
        super(FeatureRegression, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1), norm_layer(512), nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1), norm_layer(256), nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, padding=1), norm_layer(128), nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), norm_layer(64), nn.ReLU()
        )
        self.linear = nn.Linear(64 * (input_nc // 16), output_size)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.conv(x)
        x = self.linear(x.reshape(x.size(0), -1))
        return self.tanh(x)


class TpsGridGen(nn.Module):
    def __init__(self, opt, dtype=torch.float):
        super(TpsGridGen, self).__init__()

        # Create a grid in numpy.
        # TODO: set an appropriate interval ([-1, 1] in CP-VTON, [-0.9, 0.9] in the current version of VITON-HD)
        grid_X, grid_Y = np.meshgrid(np.linspace(-0.9, 0.9, opt.load_width), np.linspace(-0.9, 0.9, opt.load_height))
        grid_X = torch.tensor(grid_X, dtype=dtype).unsqueeze(0).unsqueeze(3)  # size: (1, h, w, 1)
        grid_Y = torch.tensor(grid_Y, dtype=dtype).unsqueeze(0).unsqueeze(3)  # size: (1, h, w, 1)

        # Initialize the regular grid for control points P.
        self.N = opt.grid_size * opt.grid_size
        coords = np.linspace(-0.9, 0.9, opt.grid_size)
        # FIXME: why are P_Y and P_X swapped?
        P_Y, P_X = np.meshgrid(coords, coords)
        P_X = torch.tensor(P_X, dtype=dtype).reshape(self.N, 1)
        P_Y = torch.tensor(P_Y, dtype=dtype).reshape(self.N, 1)
        P_X_base = P_X.clone()
        P_Y_base = P_Y.clone()

        Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
        P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)  # size: (1, 1, 1, 1, self.N)
        P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)  # size: (1, 1, 1, 1, self.N)

        self.register_buffer('grid_X', grid_X, False)
        self.register_buffer('grid_Y', grid_Y, False)
        self.register_buffer('P_X_base', P_X_base, False)
        self.register_buffer('P_Y_base', P_Y_base, False)
        self.register_buffer('Li', Li, False)
        self.register_buffer('P_X', P_X, False)
        self.register_buffer('P_Y', P_Y, False)

    # TODO: refactor
    def compute_L_inverse(self, X, Y):
        N = X.size()[0]  # num of points (along dim 0)
        # construct matrix K
        Xmat = X.expand(N, N)
        Ymat = Y.expand(N, N)
        P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
        P_dist_squared[P_dist_squared == 0] = 1  # make diagonal 1 to avoid NaN in log computation
        K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
        # construct matrix L
        O = torch.FloatTensor(N, 1).fill_(1)
        Z = torch.FloatTensor(3, 3).fill_(0)
        P = torch.cat((O, X, Y), 1)
        L = torch.cat((torch.cat((K, P), 1), torch.cat((P.transpose(0, 1), Z), 1)), 0)
        Li = torch.inverse(L)
        return Li

    # TODO: refactor
    def apply_transformation(self, theta, points):
        if theta.dim() == 2:
            theta = theta.unsqueeze(2).unsqueeze(3)
        # points should be in the [B,H,W,2] format,
        # where points[:,:,:,0] are the X coords
        # and points[:,:,:,1] are the Y coords

        # input are the corresponding control points P_i
        batch_size = theta.size()[0]
        # split theta into point coordinates
        Q_X = theta[:, :self.N, :, :].squeeze(3)
        Q_Y = theta[:, self.N:, :, :].squeeze(3)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)

        # get spatial dimensions of points
        points_b = points.size()[0]
        points_h = points.size()[1]
        points_w = points.size()[2]

        # repeat pre-defined control points along spatial dimensions of points to be transformed
        P_X = self.P_X.expand((1, points_h, points_w, 1, self.N))
        P_Y = self.P_Y.expand((1, points_h, points_w, 1, self.N))

        # compute weights for non-linear part
        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand((batch_size, self.N, self.N)), Q_X)
        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand((batch_size, self.N, self.N)), Q_Y)
        # reshape
        # W_X, W_Y: size [B,H,W,1,N]
        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        # compute weights for affine part
        A_X = torch.bmm(self.Li[:, self.N:, :self.N].expand((batch_size, 3, self.N)), Q_X)
        A_Y = torch.bmm(self.Li[:, self.N:, :self.N].expand((batch_size, 3, self.N)), Q_Y)
        # reshape
        # A_X, A_Y: size [B,H,W,1,3]
        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)

        # compute distance P_i - (grid_X,grid_Y)
        # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch
        points_X_for_summation = points[:, :, :, 0].unsqueeze(3).unsqueeze(4).expand(points[:, :, :, 0].size() + (1, self.N))
        points_Y_for_summation = points[:, :, :, 1].unsqueeze(3).unsqueeze(4).expand(points[:, :, :, 1].size() + (1, self.N))

        if points_b == 1:
            delta_X = points_X_for_summation - P_X
            delta_Y = points_Y_for_summation - P_Y
        else:
            # use expanded P_X,P_Y in batch dimension
            delta_X = points_X_for_summation - P_X.expand_as(points_X_for_summation)
            delta_Y = points_Y_for_summation - P_Y.expand_as(points_Y_for_summation)

        dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
        # U: size [1,H,W,1,N]
        dist_squared[dist_squared == 0] = 1  # avoid NaN in log computation
        U = torch.mul(dist_squared, torch.log(dist_squared))

        # expand grid in batch dimension if necessary
        points_X_batch = points[:, :, :, 0].unsqueeze(3)
        points_Y_batch = points[:, :, :, 1].unsqueeze(3)
        if points_b == 1:
            points_X_batch = points_X_batch.expand((batch_size,) + points_X_batch.size()[1:])
            points_Y_batch = points_Y_batch.expand((batch_size,) + points_Y_batch.size()[1:])

        points_X_prime = A_X[:, :, :, :, 0] + \
            torch.mul(A_X[:, :, :, :, 1], points_X_batch) + \
            torch.mul(A_X[:, :, :, :, 2], points_Y_batch) + \
            torch.sum(torch.mul(W_X, U.expand_as(W_X)), 4)

        points_Y_prime = A_Y[:, :, :, :, 0] + \
            torch.mul(A_Y[:, :, :, :, 1], points_X_batch) + \
            torch.mul(A_Y[:, :, :, :, 2], points_Y_batch) + \
            torch.sum(torch.mul(W_Y, U.expand_as(W_Y)), 4)

        return torch.cat((points_X_prime, points_Y_prime), 3)

    def forward(self, theta):
        warped_grid = self.apply_transformation(theta, torch.cat((self.grid_X, self.grid_Y), 3))
        return warped_grid


class GMM(nn.Module):
    def __init__(self, opt, inputA_nc, inputB_nc):
        super(GMM, self).__init__()

        self.extractionA = FeatureExtraction(inputA_nc, ngf=64, num_layers=4)
        self.extractionB = FeatureExtraction(inputB_nc, ngf=64, num_layers=4)
        self.correlation = FeatureCorrelation()
        self.regression = FeatureRegression(input_nc=(opt.load_width // 64) * (opt.load_height // 64),
                                            output_size=2 * opt.grid_size**2)
        self.gridGen = TpsGridGen(opt)

    def forward(self, inputA, inputB):
        featureA = F.normalize(self.extractionA(inputA), dim=1)
        featureB = F.normalize(self.extractionB(inputB), dim=1)
        corr = self.correlation(featureA, featureB)
        theta = self.regression(corr)

        warped_grid = self.gridGen(theta)
        return theta, warped_grid


# ----------------------------------------------------------------------------------------------------------------------
# ALIASGenerator-related classes
# ----------------------------------------------------------------------------------------------------------------------
class MaskNorm(nn.Module):
    def __init__(self, norm_nc):
        super(MaskNorm, self).__init__()

        self.norm_layer = nn.InstanceNorm2d(norm_nc, affine=False)

    def normalize_region(self, region, mask):
        b, c, h, w = region.size()

        num_pixels = mask.sum((2, 3), keepdim=True)  # size: (b, 1, 1, 1)
        num_pixels[num_pixels == 0] = 1
        mu = region.sum((2, 3), keepdim=True) / num_pixels  # size: (b, c, 1, 1)

        normalized_region = self.norm_layer(region + (1 - mask) * mu)
        return normalized_region * torch.sqrt(num_pixels / (h * w))

    def forward(self, x, mask):
        mask = mask.detach()
        normalized_foreground = self.normalize_region(x * mask, mask)
        normalized_background = self.normalize_region(x * (1 - mask), 1 - mask)
        return normalized_foreground + normalized_background


class ALIASNorm(nn.Module):
    def __init__(self, norm_type, norm_nc, label_nc):
        super(ALIASNorm, self).__init__()

        self.noise_scale = nn.Parameter(torch.zeros(norm_nc))
|
361 |
+
|
362 |
+
assert norm_type.startswith('alias')
|
363 |
+
param_free_norm_type = norm_type[len('alias'):]
|
364 |
+
if param_free_norm_type == 'batch':
|
365 |
+
self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
|
366 |
+
elif param_free_norm_type == 'instance':
|
367 |
+
self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
|
368 |
+
elif param_free_norm_type == 'mask':
|
369 |
+
self.param_free_norm = MaskNorm(norm_nc)
|
370 |
+
else:
|
371 |
+
raise ValueError(
|
372 |
+
"'{}' is not a recognized parameter-free normalization type in ALIASNorm".format(param_free_norm_type)
|
373 |
+
)
|
374 |
+
|
375 |
+
nhidden = 128
|
376 |
+
ks = 3
|
377 |
+
pw = ks // 2
|
378 |
+
self.conv_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
|
379 |
+
self.conv_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
380 |
+
self.conv_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
|
381 |
+
|
382 |
+
def forward(self, x, seg, misalign_mask=None):
|
383 |
+
# Part 1. Generate parameter-free normalized activations.
|
384 |
+
b, c, h, w = x.size()
|
385 |
+
noise = (torch.randn(b, w, h, 1).cuda() * self.noise_scale).transpose(1, 3)
|
386 |
+
|
387 |
+
if misalign_mask is None:
|
388 |
+
normalized = self.param_free_norm(x + noise)
|
389 |
+
else:
|
390 |
+
normalized = self.param_free_norm(x + noise, misalign_mask)
|
391 |
+
|
392 |
+
# Part 2. Produce affine parameters conditioned on the segmentation map.
|
393 |
+
actv = self.conv_shared(seg)
|
394 |
+
gamma = self.conv_gamma(actv)
|
395 |
+
beta = self.conv_beta(actv)
|
396 |
+
|
397 |
+
# Apply the affine parameters.
|
398 |
+
output = normalized * (1 + gamma) + beta
|
399 |
+
return output
|
400 |
+
|
401 |
+
|
402 |
+
class ALIASResBlock(nn.Module):
|
403 |
+
def __init__(self, opt, input_nc, output_nc, use_mask_norm=True):
|
404 |
+
super(ALIASResBlock, self).__init__()
|
405 |
+
|
406 |
+
self.learned_shortcut = (input_nc != output_nc)
|
407 |
+
middle_nc = min(input_nc, output_nc)
|
408 |
+
|
409 |
+
self.conv_0 = nn.Conv2d(input_nc, middle_nc, kernel_size=3, padding=1)
|
410 |
+
self.conv_1 = nn.Conv2d(middle_nc, output_nc, kernel_size=3, padding=1)
|
411 |
+
if self.learned_shortcut:
|
412 |
+
self.conv_s = nn.Conv2d(input_nc, output_nc, kernel_size=1, bias=False)
|
413 |
+
|
414 |
+
subnorm_type = opt.norm_G
|
415 |
+
if subnorm_type.startswith('spectral'):
|
416 |
+
subnorm_type = subnorm_type[len('spectral'):]
|
417 |
+
self.conv_0 = spectral_norm(self.conv_0)
|
418 |
+
self.conv_1 = spectral_norm(self.conv_1)
|
419 |
+
if self.learned_shortcut:
|
420 |
+
self.conv_s = spectral_norm(self.conv_s)
|
421 |
+
|
422 |
+
semantic_nc = opt.semantic_nc
|
423 |
+
if use_mask_norm:
|
424 |
+
subnorm_type = 'aliasmask'
|
425 |
+
semantic_nc = semantic_nc + 1
|
426 |
+
|
427 |
+
self.norm_0 = ALIASNorm(subnorm_type, input_nc, semantic_nc)
|
428 |
+
self.norm_1 = ALIASNorm(subnorm_type, middle_nc, semantic_nc)
|
429 |
+
if self.learned_shortcut:
|
430 |
+
self.norm_s = ALIASNorm(subnorm_type, input_nc, semantic_nc)
|
431 |
+
|
432 |
+
self.relu = nn.LeakyReLU(0.2)
|
433 |
+
|
434 |
+
def shortcut(self, x, seg, misalign_mask):
|
435 |
+
if self.learned_shortcut:
|
436 |
+
return self.conv_s(self.norm_s(x, seg, misalign_mask))
|
437 |
+
else:
|
438 |
+
return x
|
439 |
+
|
440 |
+
def forward(self, x, seg, misalign_mask=None):
|
441 |
+
seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
|
442 |
+
if misalign_mask is not None:
|
443 |
+
misalign_mask = F.interpolate(misalign_mask, size=x.size()[2:], mode='nearest')
|
444 |
+
|
445 |
+
x_s = self.shortcut(x, seg, misalign_mask)
|
446 |
+
|
447 |
+
dx = self.conv_0(self.relu(self.norm_0(x, seg, misalign_mask)))
|
448 |
+
dx = self.conv_1(self.relu(self.norm_1(dx, seg, misalign_mask)))
|
449 |
+
output = x_s + dx
|
450 |
+
return output
|
451 |
+
|
452 |
+
|
453 |
+
class ALIASGenerator(BaseNetwork):
|
454 |
+
def __init__(self, opt, input_nc):
|
455 |
+
super(ALIASGenerator, self).__init__()
|
456 |
+
self.num_upsampling_layers = opt.num_upsampling_layers
|
457 |
+
|
458 |
+
self.sh, self.sw = self.compute_latent_vector_size(opt)
|
459 |
+
|
460 |
+
nf = opt.ngf
|
461 |
+
self.conv_0 = nn.Conv2d(input_nc, nf * 16, kernel_size=3, padding=1)
|
462 |
+
for i in range(1, 8):
|
463 |
+
self.add_module('conv_{}'.format(i), nn.Conv2d(input_nc, 16, kernel_size=3, padding=1))
|
464 |
+
|
465 |
+
self.head_0 = ALIASResBlock(opt, nf * 16, nf * 16)
|
466 |
+
|
467 |
+
self.G_middle_0 = ALIASResBlock(opt, nf * 16 + 16, nf * 16)
|
468 |
+
self.G_middle_1 = ALIASResBlock(opt, nf * 16 + 16, nf * 16)
|
469 |
+
|
470 |
+
self.up_0 = ALIASResBlock(opt, nf * 16 + 16, nf * 8)
|
471 |
+
self.up_1 = ALIASResBlock(opt, nf * 8 + 16, nf * 4)
|
472 |
+
self.up_2 = ALIASResBlock(opt, nf * 4 + 16, nf * 2, use_mask_norm=False)
|
473 |
+
self.up_3 = ALIASResBlock(opt, nf * 2 + 16, nf * 1, use_mask_norm=False)
|
474 |
+
if self.num_upsampling_layers == 'most':
|
475 |
+
self.up_4 = ALIASResBlock(opt, nf * 1 + 16, nf // 2, use_mask_norm=False)
|
476 |
+
nf = nf // 2
|
477 |
+
|
478 |
+
self.conv_img = nn.Conv2d(nf, 3, kernel_size=3, padding=1)
|
479 |
+
|
480 |
+
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
481 |
+
self.relu = nn.LeakyReLU(0.2)
|
482 |
+
self.tanh = nn.Tanh()
|
483 |
+
|
484 |
+
self.print_network()
|
485 |
+
self.init_weights(opt.init_type, opt.init_variance)
|
486 |
+
|
487 |
+
def compute_latent_vector_size(self, opt):
|
488 |
+
if self.num_upsampling_layers == 'normal':
|
489 |
+
num_up_layers = 5
|
490 |
+
elif self.num_upsampling_layers == 'more':
|
491 |
+
num_up_layers = 6
|
492 |
+
elif self.num_upsampling_layers == 'most':
|
493 |
+
num_up_layers = 7
|
494 |
+
else:
|
495 |
+
raise ValueError("opt.num_upsampling_layers '{}' is not recognized".format(self.num_upsampling_layers))
|
496 |
+
|
497 |
+
sh = opt.load_height // 2**num_up_layers
|
498 |
+
sw = opt.load_width // 2**num_up_layers
|
499 |
+
return sh, sw
|
500 |
+
|
501 |
+
def forward(self, x, seg, seg_div, misalign_mask):
|
502 |
+
samples = [F.interpolate(x, size=(self.sh * 2**i, self.sw * 2**i), mode='nearest') for i in range(8)]
|
503 |
+
features = [self._modules['conv_{}'.format(i)](samples[i]) for i in range(8)]
|
504 |
+
|
505 |
+
x = self.head_0(features[0], seg_div, misalign_mask)
|
506 |
+
|
507 |
+
x = self.up(x)
|
508 |
+
x = self.G_middle_0(torch.cat((x, features[1]), 1), seg_div, misalign_mask)
|
509 |
+
if self.num_upsampling_layers in ['more', 'most']:
|
510 |
+
x = self.up(x)
|
511 |
+
x = self.G_middle_1(torch.cat((x, features[2]), 1), seg_div, misalign_mask)
|
512 |
+
|
513 |
+
x = self.up(x)
|
514 |
+
x = self.up_0(torch.cat((x, features[3]), 1), seg_div, misalign_mask)
|
515 |
+
x = self.up(x)
|
516 |
+
x = self.up_1(torch.cat((x, features[4]), 1), seg_div, misalign_mask)
|
517 |
+
x = self.up(x)
|
518 |
+
x = self.up_2(torch.cat((x, features[5]), 1), seg)
|
519 |
+
x = self.up(x)
|
520 |
+
x = self.up_3(torch.cat((x, features[6]), 1), seg)
|
521 |
+
if self.num_upsampling_layers == 'most':
|
522 |
+
x = self.up(x)
|
523 |
+
x = self.up_4(torch.cat((x, features[7]), 1), seg)
|
524 |
+
|
525 |
+
x = self.conv_img(self.relu(x))
|
526 |
+
return self.tanh(x)
|
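
A minimal usage sketch of the GMM above, not from the committed files: the warped grid returned by forward is in the [B, H, W, 2] layout that torch.nn.functional.grid_sample expects, so the flat cloth image is typically resampled through it. All option values, channel counts, and shapes below are placeholder assumptions (the real values come from test.py's argument parser, and TpsGridGen may read further fields from `opt`):

import torch
import torch.nn.functional as F
from types import SimpleNamespace

from network import GMM  # defined above in this repo's network.py

# Hypothetical option values, flagged as assumptions.
opt = SimpleNamespace(load_width=768, load_height=1024, grid_size=5)
gmm = GMM(opt, inputA_nc=7, inputB_nc=3)  # channel counts are illustrative only

# Assumption: VITON-HD-style pipelines feed the matcher quarter-resolution inputs
# so that FeatureRegression's input_nc = (load_width // 64) * (load_height // 64).
person_down = torch.randn(1, 7, opt.load_height // 4, opt.load_width // 4)
cloth_full = torch.randn(1, 3, opt.load_height, opt.load_width)
cloth_down = F.interpolate(cloth_full, size=(opt.load_height // 4, opt.load_width // 4),
                           mode='bilinear')

theta, warped_grid = gmm(person_down, cloth_down)  # warped_grid: [B, H, W, 2]
# Resample the full-resolution cloth through the predicted TPS grid.
warped_cloth = F.grid_sample(cloth_full, warped_grid, padding_mode='border')

The output size of grid_sample follows the grid's spatial size, so the sketch stays valid whether TpsGridGen builds its grid at full or reduced resolution.
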
networks/__init__.py
ADDED
@@ -0,0 +1 @@
from .u2net import U2NET

networks/u2net.py
ADDED
@@ -0,0 +1,565 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        self.conv_s1 = nn.Conv2d(
            in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):

        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))

        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):

    # F.upsample is deprecated; F.interpolate is the equivalent, current call
    src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")

    return src


### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x
        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)

        hx6 = self.rebnconv6(hx)

        hx7 = self.rebnconv7(hx6)

        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6, hx7
        del hx6d, hx5d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup, hx5dup, hx6dup
        """

        return hx1d + hxin


### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)

        hx5 = self.rebnconv5(hx)

        hx6 = self.rebnconv6(hx5)

        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup, hx5dup
        """

        return hx1d + hxin


### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)

        hx4 = self.rebnconv4(hx)

        hx5 = self.rebnconv5(hx4)

        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4, hx5
        del hx4d, hx3d, hx2d
        del hx2dup, hx3dup, hx4dup
        """

        return hx1d + hxin


### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)

        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)

        hx3 = self.rebnconv3(hx)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))

        """
        del hx1, hx2, hx3, hx4
        del hx3d, hx2d
        del hx2dup, hx3dup
        """

        return hx1d + hxin


### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()

        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):

        hx = x

        hxin = self.rebnconvin(hx)

        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)

        hx4 = self.rebnconv4(hx3)

        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))

        """
        del hx1, hx2, hx3, hx4
        del hx3d, hx2d
        """

        return hx1d + hxin


##### U^2-Net ####
class U2NET(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):

        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d, hx1d
        del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
        """

        return d0, d1, d2, d3, d4, d5, d6


### U^2-Net small ###
class U2NETP(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):

        hx = x

        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        """
        del hx1, hx2, hx3, hx4, hx5, hx6
        del hx5d, hx4d, hx3d, hx2d, hx1d
        del hx6up, hx5dup, hx4dup, hx3dup, hx2dup
        """

        return d0, d1, d2, d3, d4, d5, d6
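
A minimal smoke test for the network above, not from the committed files: U2NET returns the fused map d0 plus six side outputs, all at the input resolution, and this variant leaves them as raw logits. The out_ch=1 and the 320x320 shape below are illustrative assumptions; the cloth-segmentation checkpoint shipped with this Space (cloth_segm_u2net_latest.pth) is loaded elsewhere (cloth-mask.py), so random weights here only demonstrate shapes:

import torch

from networks import U2NET  # re-exported by networks/__init__.py above

net = U2NET(in_ch=3, out_ch=1)
net.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 320, 320)          # NCHW input batch
    d0, d1, d2, d3, d4, d5, d6 = net(x)      # fused map plus six side outputs
    mask = torch.sigmoid(d0)                 # squash logits into (0, 1)
print(mask.shape)                            # torch.Size([1, 1, 320, 320])
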
remove_bg.py
ADDED
@@ -0,0 +1,59 @@
import requests
import os
from PIL import Image
import numpy as np
from rembg import remove


class preprcessInput:

    def __init__(self):
        self.o_width = None
        self.o_height = None
        self.o_image = None

        self.t_width = None
        self.t_height = None
        self.t_image = None
        self.save_path = None

    def remove_bg(self, file_path: str):
        # file_path[:-3] drops the three-letter extension but keeps the dot,
        # so append 'png' (not '.png') to avoid a stray double dot in the name
        self.save_path = file_path[:-3] + 'png'
        pic = Image.open(file_path)
        self.o_width = np.asarray(pic).shape[1]
        self.o_height = np.asarray(pic).shape[0]
        try:
            self.o_channels = np.asarray(pic).shape[2]
        except Exception as e:
            print("Single channel image and error", e)
            os.remove(file_path)
        self.o_image = remove(pic)
        self.o_image.save(self.save_path)
        os.remove(self.save_path)
        return np.asarray(self.o_image)

    def transform(self, width=768, height=1024):
        newsize = (width, height)
        self.t_height = height
        self.t_width = width

        pic = self.o_image
        img = pic.resize(newsize)

        self.t_image = img

        # composite the RGBA cutout onto a white background
        background = Image.new("RGBA", newsize, (255, 255, 255, 255))
        background.paste(img, mask=img.split()[3])  # 3 is the alpha channel
        self.save_path = self.save_path[:-3] + 'jpg'  # same no-double-dot fix as above
        background.convert('RGB').save(self.save_path, 'JPEG')

        return np.asarray(background.convert('RGB'))


# USAGE OF THE CLASS
preprocess = preprcessInput()
for images in os.listdir('/content/inputs/test/image'):
    print(images)
    if images[-3:] == 'jpg':
        op = preprocess.remove_bg(r'/content/inputs/test/image/' + images)
        arr = preprocess.transform(768, 1024)
run.py
ADDED
@@ -0,0 +1,41 @@
from PIL import Image
import os


# running the preprocessing

def resize_img(path):
    im = Image.open(path)
    im = im.resize((768, 1024))
    im.save(path)


for path in os.listdir('/content/inputs/test/cloth/'):
    resize_img(f'/content/inputs/test/cloth/{path}')

os.chdir('/content/clothes-virtual-try-on')
os.system("rm -rf /content/inputs/test/cloth/.ipynb_checkpoints")
os.system("python cloth-mask.py")
os.chdir('/content')
os.system("python /content/clothes-virtual-try-on/remove_bg.py")
os.system(
    "python3 /content/Self-Correction-Human-Parsing/simple_extractor.py --dataset 'lip' --model-restore '/content/Self-Correction-Human-Parsing/checkpoints/final.pth' --input-dir '/content/inputs/test/image' --output-dir '/content/inputs/test/image-parse'")
os.chdir('/content')
os.system(
    "cd openpose && ./build/examples/openpose/openpose.bin --image_dir /content/inputs/test/image/ --write_json /content/inputs/test/openpose-json/ --display 0 --render_pose 0 --hand")
os.system(
    "cd openpose && ./build/examples/openpose/openpose.bin --image_dir /content/inputs/test/image/ --display 0 --write_images /content/inputs/test/openpose-img/ --hand --render_pose 1 --disable_blending true")

model_image = os.listdir('/content/inputs/test/image')
cloth_image = os.listdir('/content/inputs/test/cloth')
pairs = zip(model_image, cloth_image)

with open('/content/inputs/test_pairs.txt', 'w') as file:
    for model, cloth in pairs:
        file.write(f"{model} {cloth}\n")  # newline added so each pair sits on its own line

# making predictions
os.system(
    "python /content/clothes-virtual-try-on/test.py --name output --dataset_dir /content/inputs --checkpoint_dir /content/clothes-virtual-try-on/checkpoints --save_dir /content/")
os.system("rm -rf /content/inputs")
os.system("rm -rf /content/output/.ipynb_checkpoints")
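
A note on the pairs file, hedged: test.py's loader (VITON-HD-style) reads one whitespace-separated `<person image> <cloth image>` pair per line from test_pairs.txt, which is why a newline is appended in the write above; without it, multiple pairs would run together on one line. With hypothetical filenames, the file would look like:

person_01.jpg cloth_01.jpg
person_02.jpg cloth_02.jpg
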
setup_gradio.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
setup_ngrok.ipynb
ADDED
@@ -0,0 +1,643 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "8gqt11Y_RYxU"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"# Setting up the environment. PLEASE WAIT 🙃"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": null,
|
15 |
+
"metadata": {
|
16 |
+
"id": "RHmGnnTZL6os"
|
17 |
+
},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"!pip install --upgrade --no-cache-dir gdown\n",
|
21 |
+
"!pip install rembg[gpu]"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": null,
|
27 |
+
"metadata": {
|
28 |
+
"id": "uX3LsFFPKSwo"
|
29 |
+
},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"! wget -c \"https://github.com/Kitware/CMake/releases/download/v3.19.6/cmake-3.19.6.tar.gz\"\n",
|
33 |
+
"! tar xf cmake-3.19.6.tar.gz\n",
|
34 |
+
"! cd cmake-3.19.6 && ./configure && make && sudo make install"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": null,
|
40 |
+
"metadata": {
|
41 |
+
"id": "51QJAhPOK9cK"
|
42 |
+
},
|
43 |
+
"outputs": [],
|
44 |
+
"source": [
|
45 |
+
"# Install library\n",
|
46 |
+
"! sudo apt-get --assume-yes update\n",
|
47 |
+
"! sudo apt-get --assume-yes install build-essential\n",
|
48 |
+
"# OpenCV\n",
|
49 |
+
"! sudo apt-get --assume-yes install libopencv-dev\n",
|
50 |
+
"# General dependencies\n",
|
51 |
+
"! sudo apt-get --assume-yes install libatlas-base-dev libprotobuf-dev libleveldb-dev libsnappy-dev libhdf5-serial-dev protobuf-compiler\n",
|
52 |
+
"! sudo apt-get --assume-yes install --no-install-recommends libboost-all-dev\n",
|
53 |
+
"# Remaining dependencies, 14.04\n",
|
54 |
+
"! sudo apt-get --assume-yes install libgflags-dev libgoogle-glog-dev liblmdb-dev\n",
|
55 |
+
"# Python3 libs\n",
|
56 |
+
"! sudo apt-get --assume-yes install python3-setuptools python3-dev build-essential\n",
|
57 |
+
"! sudo apt-get --assume-yes install python3-pip\n",
|
58 |
+
"! sudo -H pip3 install --upgrade numpy protobuf opencv-python\n",
|
59 |
+
"# OpenCL Generic\n",
|
60 |
+
"! sudo apt-get --assume-yes install opencl-headers ocl-icd-opencl-dev\n",
|
61 |
+
"! sudo apt-get --assume-yes install libviennacl-dev"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 3,
|
67 |
+
"metadata": {
|
68 |
+
"colab": {
|
69 |
+
"base_uri": "https://localhost:8080/"
|
70 |
+
},
|
71 |
+
"id": "x7tqqDLHLNDr",
|
72 |
+
"outputId": "dc52fffb-2375-4283-da82-f6327a4d73ad"
|
73 |
+
},
|
74 |
+
"outputs": [
|
75 |
+
{
|
76 |
+
"name": "stdout",
|
77 |
+
"output_type": "stream",
|
78 |
+
"text": [
|
79 |
+
"v1.7.0\n"
|
80 |
+
]
|
81 |
+
}
|
82 |
+
],
|
83 |
+
"source": [
|
84 |
+
"ver_openpose = \"v1.7.0\"\n",
|
85 |
+
"! echo $ver_openpose"
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": null,
|
91 |
+
"metadata": {
|
92 |
+
"id": "b11_hkSgLDl5"
|
93 |
+
},
|
94 |
+
"outputs": [],
|
95 |
+
"source": [
|
96 |
+
"! git clone --depth 1 -b \"$ver_openpose\" https://github.com/CMU-Perceptual-Computing-Lab/openpose.git"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {
|
103 |
+
"id": "s-XyxfV8Q-DE"
|
104 |
+
},
|
105 |
+
"outputs": [],
|
106 |
+
"source": [
|
107 |
+
"# manually downloading openpose models\n",
|
108 |
+
"%%bash\n",
|
109 |
+
"gdown 1QCSxJZpnWvM00hx49CJ2zky7PWGzpcEh\n",
|
110 |
+
"unzip models.zip\n",
|
111 |
+
"mv /content/models/face/pose_iter_116000.caffemodel /content/openpose/models/face/pose_iter_116000.caffemodel\n",
|
112 |
+
"mv /content/models/hand/pose_iter_102000.caffemodel /content/openpose/models/hand/pose_iter_102000.caffemodel\n",
|
113 |
+
"mv /content/models/pose/body_25/pose_iter_584000.caffemodel /content/openpose/models/pose/body_25/pose_iter_584000.caffemodel\n",
|
114 |
+
"mv /content/models/pose/coco/pose_iter_440000.caffemodel /content/openpose/models/pose/coco/pose_iter_440000.caffemodel\n",
|
115 |
+
"mv /content/models/pose/mpi/pose_iter_160000.caffemodel /content/openpose/models/pose/mpi/pose_iter_160000.caffemodel\n",
|
116 |
+
"rm -rf models\n",
|
117 |
+
"rm models.zip"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": 6,
|
123 |
+
"metadata": {
|
124 |
+
"id": "Bs-zIObzQLYj"
|
125 |
+
},
|
126 |
+
"outputs": [],
|
127 |
+
"source": [
|
128 |
+
"! cd openpose && mkdir build && cd build"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": null,
|
134 |
+
"metadata": {
|
135 |
+
"id": "7i7oHh2vQqHv"
|
136 |
+
},
|
137 |
+
"outputs": [],
|
138 |
+
"source": [
|
139 |
+
"! cd openpose/build && cmake -DUSE_CUDNN=OFF -DBUILD_PYTHON=ON .."
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": null,
|
145 |
+
"metadata": {
|
146 |
+
"id": "iBvxsDM-EYJk"
|
147 |
+
},
|
148 |
+
"outputs": [],
|
149 |
+
"source": [
|
150 |
+
"# ! cd openpose/build && cmake .."
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"metadata": {
|
157 |
+
"id": "XEAY8VW0RzD0"
|
158 |
+
},
|
159 |
+
"outputs": [],
|
160 |
+
"source": [
|
161 |
+
"! cd openpose/build && make -j`nproc`\n",
|
162 |
+
"! cd openpose && mkdir output"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 9,
|
168 |
+
"metadata": {
|
169 |
+
"colab": {
|
170 |
+
"base_uri": "https://localhost:8080/"
|
171 |
+
},
|
172 |
+
"id": "60nEQBKefg3f",
|
173 |
+
"outputId": "91903854-5dc4-4661-c6d6-cae5ab56bdf2"
|
174 |
+
},
|
175 |
+
"outputs": [
|
176 |
+
{
|
177 |
+
"name": "stdout",
|
178 |
+
"output_type": "stream",
|
179 |
+
"text": [
|
180 |
+
"Collecting flask-ngrok\n",
|
181 |
+
" Downloading flask_ngrok-0.0.25-py3-none-any.whl (3.1 kB)\n",
|
182 |
+
"Requirement already satisfied: Flask>=0.8 in /usr/local/lib/python3.10/dist-packages (from flask-ngrok) (2.2.5)\n",
|
183 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from flask-ngrok) (2.31.0)\n",
|
184 |
+
"Requirement already satisfied: Werkzeug>=2.2.2 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (3.0.1)\n",
|
185 |
+
"Requirement already satisfied: Jinja2>=3.0 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (3.1.2)\n",
|
186 |
+
"Requirement already satisfied: itsdangerous>=2.0 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (2.1.2)\n",
|
187 |
+
"Requirement already satisfied: click>=8.0 in /usr/local/lib/python3.10/dist-packages (from Flask>=0.8->flask-ngrok) (8.1.7)\n",
|
188 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (3.3.2)\n",
|
189 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (3.6)\n",
|
190 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (2.0.7)\n",
|
191 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->flask-ngrok) (2023.11.17)\n",
|
192 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from Jinja2>=3.0->Flask>=0.8->flask-ngrok) (2.1.3)\n",
|
193 |
+
"Installing collected packages: flask-ngrok\n",
|
194 |
+
"Successfully installed flask-ngrok-0.0.25\n",
|
195 |
+
"Collecting pyngrok==4.1.1\n",
|
196 |
+
" Downloading pyngrok-4.1.1.tar.gz (18 kB)\n",
|
197 |
+
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
198 |
+
"Requirement already satisfied: future in /usr/local/lib/python3.10/dist-packages (from pyngrok==4.1.1) (0.18.3)\n",
|
199 |
+
"Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from pyngrok==4.1.1) (6.0.1)\n",
|
200 |
+
"Building wheels for collected packages: pyngrok\n",
|
201 |
+
" Building wheel for pyngrok (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
202 |
+
" Created wheel for pyngrok: filename=pyngrok-4.1.1-py3-none-any.whl size=15963 sha256=b8ae0d70bccacfc72462262492c96e88942fd5e8113a0a9c6745caea83aad689\n",
|
203 |
+
" Stored in directory: /root/.cache/pip/wheels/4c/7c/4c/632fba2ea8e88d8890102eb07bc922e1ca8fa14db5902c91a8\n",
|
204 |
+
"Successfully built pyngrok\n",
|
205 |
+
"Installing collected packages: pyngrok\n",
|
206 |
+
"Successfully installed pyngrok-4.1.1\n",
|
207 |
+
"Authtoken saved to configuration file: /root/.ngrok2/ngrok.yml\n"
|
208 |
+
]
|
209 |
+
}
|
210 |
+
],
|
211 |
+
"source": [
|
212 |
+
"!pip install flask-ngrok\n",
|
213 |
+
"!pip install pyngrok==4.1.1\n",
|
214 |
+
"!ngrok authtoken <your_token>"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "code",
|
219 |
+
"execution_count": 10,
|
220 |
+
"metadata": {
|
221 |
+
"colab": {
|
222 |
+
"base_uri": "https://localhost:8080/"
|
223 |
+
},
|
224 |
+
"id": "Fo-q2Q-XMFen",
|
225 |
+
"outputId": "a762d258-7da5-44fc-fdad-7fef43ddd361"
|
226 |
+
},
|
227 |
+
"outputs": [
|
228 |
+
{
|
229 |
+
"name": "stdout",
|
230 |
+
"output_type": "stream",
|
231 |
+
"text": [
|
232 |
+
"/content\n",
|
233 |
+
"Cloning into 'clothes-virtual-try-on'...\n",
|
234 |
+
"remote: Enumerating objects: 154, done.\u001b[K\n",
|
235 |
+
"remote: Counting objects: 100% (22/22), done.\u001b[K\n",
|
236 |
+
"remote: Compressing objects: 100% (10/10), done.\u001b[K\n",
|
237 |
+
"remote: Total 154 (delta 16), reused 12 (delta 12), pack-reused 132\u001b[K\n",
|
238 |
+
"Receiving objects: 100% (154/154), 20.47 MiB | 33.87 MiB/s, done.\n",
|
239 |
+
"Resolving deltas: 100% (54/54), done.\n"
|
240 |
+
]
|
241 |
+
}
|
242 |
+
],
|
243 |
+
"source": [
|
244 |
+
"import os\n",
|
245 |
+
"%cd /content/\n",
|
246 |
+
"!rm -rf clothes-virtual-try-on\n",
|
247 |
+
"!git clone https://github.com/practice404/clothes-virtual-try-on.git\n",
|
248 |
+
"os.makedirs(\"/content/clothes-virtual-try-on/checkpoints\")"
|
249 |
+
]
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "code",
|
253 |
+
"execution_count": 11,
|
254 |
+
"metadata": {
|
255 |
+
"colab": {
|
256 |
+
"base_uri": "https://localhost:8080/"
|
257 |
+
},
|
258 |
+
"id": "tnud6ptA9ZwL",
|
259 |
+
"outputId": "bc5ee612-eb57-4118-f5e0-6221484c9571"
|
260 |
+
},
|
261 |
+
"outputs": [
|
262 |
+
{
|
263 |
+
"name": "stdout",
|
264 |
+
"output_type": "stream",
|
265 |
+
"text": [
|
266 |
+
"/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
|
267 |
+
" warnings.warn(\n",
|
268 |
+
"Downloading...\n",
|
269 |
+
"From (uriginal): https://drive.google.com/uc?id=18q4lS7cNt1_X8ewCgya1fq0dSk93jTL6\n",
|
270 |
+
"From (redirected): https://drive.google.com/uc?id=18q4lS7cNt1_X8ewCgya1fq0dSk93jTL6&confirm=t&uuid=6db7053d-df6d-41e3-ac5b-2f924303335f\n",
|
271 |
+
"To: /content/clothes-virtual-try-on/checkpoints/alias_final.pth\n",
|
272 |
+
"100% 402M/402M [00:01<00:00, 254MB/s]\n",
|
273 |
+
"/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
|
274 |
+
" warnings.warn(\n",
|
275 |
+
"Downloading...\n",
|
276 |
+
"From: https://drive.google.com/uc?id=1uDRPY8gh9sHb3UDonq6ZrINqDOd7pmTz\n",
|
277 |
+
"To: /content/clothes-virtual-try-on/checkpoints/gmm_final.pth\n",
|
278 |
+
"100% 76.2M/76.2M [00:00<00:00, 223MB/s]\n",
|
279 |
+
"/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
|
280 |
+
" warnings.warn(\n",
|
281 |
+
"Downloading...\n",
|
282 |
+
"From (uriginal): https://drive.google.com/uc?id=1d7lZNLh51Qt5Mi1lXqyi6Asb2ncLrEdC\n",
|
283 |
+
"From (redirected): https://drive.google.com/uc?id=1d7lZNLh51Qt5Mi1lXqyi6Asb2ncLrEdC&confirm=t&uuid=78aeda19-f21d-4598-8bdf-d08a78a99149\n",
|
284 |
+
"To: /content/clothes-virtual-try-on/checkpoints/seg_final.pth\n",
|
285 |
+
"100% 138M/138M [00:01<00:00, 135MB/s]\n"
|
286 |
+
]
|
287 |
+
}
|
288 |
+
],
|
289 |
+
"source": [
|
290 |
+
"!gdown --id 18q4lS7cNt1_X8ewCgya1fq0dSk93jTL6 --output /content/clothes-virtual-try-on/checkpoints/alias_final.pth\n",
|
291 |
+
"!gdown --id 1uDRPY8gh9sHb3UDonq6ZrINqDOd7pmTz --output /content/clothes-virtual-try-on/checkpoints/gmm_final.pth\n",
|
292 |
+
"!gdown --id 1d7lZNLh51Qt5Mi1lXqyi6Asb2ncLrEdC --output /content/clothes-virtual-try-on/checkpoints/seg_final.pth"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"cell_type": "code",
|
297 |
+
"execution_count": 12,
|
298 |
+
"metadata": {
|
299 |
+
"colab": {
|
300 |
+
"base_uri": "https://localhost:8080/"
|
301 |
+
},
|
302 |
+
"id": "qWPkjShFMK82",
|
303 |
+
"outputId": "cf51a4d3-4833-4788-9878-92a791a944b8"
|
304 |
+
},
|
305 |
+
"outputs": [
|
306 |
+
{
|
307 |
+
"name": "stdout",
|
308 |
+
"output_type": "stream",
|
309 |
+
"text": [
|
310 |
+
"/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
|
311 |
+
" warnings.warn(\n",
|
312 |
+
"Downloading...\n",
|
313 |
+
"From (uriginal): https://drive.google.com/uc?id=1ysEoAJNxou7RNuT9iKOxRhjVRNY5RLjx\n",
|
314 |
+
"From (redirected): https://drive.google.com/uc?id=1ysEoAJNxou7RNuT9iKOxRhjVRNY5RLjx&confirm=t&uuid=50dc2d49-15b3-47ed-905f-fc2455dfea07\n",
|
315 |
+
"To: /content/clothes-virtual-try-on/cloth_segm_u2net_latest.pth\n",
|
316 |
+
"100% 177M/177M [00:00<00:00, 178MB/s]\n"
|
317 |
+
]
|
318 |
+
}
|
319 |
+
],
|
320 |
+
"source": [
|
321 |
+
"!gdown --id 1ysEoAJNxou7RNuT9iKOxRhjVRNY5RLjx --output /content/clothes-virtual-try-on/cloth_segm_u2net_latest.pth --no-cookies"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "code",
|
326 |
+
"execution_count": 13,
|
327 |
+
"metadata": {
|
328 |
+
"colab": {
|
329 |
+
"base_uri": "https://localhost:8080/"
|
330 |
+
},
|
331 |
+
"id": "I9MhYntvMP84",
|
332 |
+
"outputId": "774ecb82-0f56-4b65-a6fd-0baa9416d75c"
|
333 |
+
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/content\n",
+      "Collecting ninja\n",
+      "  Downloading ninja-1.11.1.1-py2.py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl (307 kB)\n",
+      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m307.2/307.2 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+      "\u001b[?25hInstalling collected packages: ninja\n",
+      "Successfully installed ninja-1.11.1.1\n"
+     ]
+    }
+   ],
+   "source": [
+    "%cd /content/\n",
+    "!pip install ninja"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "Rz9LOnvyMWEJ",
+    "outputId": "55af9ee0-cdf5-495f-a14f-ac93695a5fbe"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Cloning into 'Self-Correction-Human-Parsing'...\n",
+      "remote: Enumerating objects: 719, done.\u001b[K\n",
+      "remote: Counting objects: 100% (719/719), done.\u001b[K\n",
+      "remote: Compressing objects: 100% (568/568), done.\u001b[K\n",
+      "remote: Total 719 (delta 149), reused 611 (delta 140), pack-reused 0\u001b[K\n",
+      "Receiving objects: 100% (719/719), 3.88 MiB | 12.81 MiB/s, done.\n",
+      "Resolving deltas: 100% (149/149), done.\n",
+      "/content/Self-Correction-Human-Parsing\n"
+     ]
+    }
+   ],
+   "source": [
+    "!git clone https://github.com/PeikeLi/Self-Correction-Human-Parsing\n",
+    "%cd Self-Correction-Human-Parsing\n",
+    "!mkdir checkpoints"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "b2k0DLCsMaG0",
+    "outputId": "a28b0d51-14a3-426b-a2cb-d9b209e2b202"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/usr/local/lib/python3.10/dist-packages/gdown/cli.py:126: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.\n",
+      "  warnings.warn(\n",
+      "Downloading...\n",
+      "From (uriginal): https://drive.google.com/uc?id=1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH\n",
+      "From (redirected): https://drive.google.com/uc?id=1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH&confirm=t&uuid=83091795-9ef5-449e-8d11-008bbe2238eb\n",
+      "To: /content/Self-Correction-Human-Parsing/exp-schp-201908261155-lip.pth\n",
+      "100% 267M/267M [00:01<00:00, 210MB/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "# downloading LIP dataset model\n",
+    "!gdown --id 1k4dllHpu0bdx38J7H28rVVLpU-kOHmnH\n",
+    "!mv /content/Self-Correction-Human-Parsing/exp-schp-201908261155-lip.pth /content/Self-Correction-Human-Parsing/checkpoints/final.pth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "2Y4f3VRyMd9Z",
+    "outputId": "c81b2103-5c6b-4af5-aad1-c51dc65b1015"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "/content\n"
+     ]
+    }
+   ],
+   "source": [
+    "%cd /content/"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "1k2dVj4vMhwA"
+   },
+   "outputs": [],
+   "source": [
+    "%%bash\n",
+    "MINICONDA_INSTALLER_SCRIPT=Miniconda3-4.5.4-Linux-x86_64.sh\n",
+    "MINICONDA_PREFIX=/usr/local\n",
+    "wget https://repo.continuum.io/miniconda/$MINICONDA_INSTALLER_SCRIPT\n",
+    "chmod +x $MINICONDA_INSTALLER_SCRIPT\n",
+    "./$MINICONDA_INSTALLER_SCRIPT -b -f -p $MINICONDA_PREFIX\n",
+    "conda install --channel defaults conda python=3.8 --yes\n",
+    "conda update --channel defaults --all --yes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {
+    "id": "I6entwp3MliV"
+   },
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "_ = (sys.path\n",
+    "     .append(\"/usr/local/lib/python3.6/site-packages\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "cseosViyMtYx"
+   },
+   "outputs": [],
+   "source": [
+    "!conda install --channel conda-forge featuretools --yes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "BEnK6NI6M0cz"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install opencv-python torchgeometry"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "HkySNWttHdW2"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install torchvision"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "-wvtaXujRhNp"
+   },
+   "source": [
+    "# Welcome to Virtual-Cloth-Assistant\n",
+    "\n",
+    "> The first run takes some extra time to set up the environment and download the model weights"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {
+    "id": "RwcUm39LM8H0"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "def make_dir():\n",
+    "    os.system(\"cd /content/ && mkdir inputs && cd inputs && mkdir test && cd test && mkdir cloth cloth-mask image image-parse openpose-img openpose-json\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "metadata": {
+    "id": "9jGRSFuEM9-q"
+   },
+   "outputs": [],
+   "source": [
+    "from flask import Flask, request, send_file, jsonify\n",
+    "from flask_ngrok import run_with_ngrok\n",
+    "from PIL import Image\n",
+    "import base64\n",
+    "import io\n",
+    "\n",
+    "app = Flask(__name__)\n",
+    "run_with_ngrok(app)\n",
+    "\n",
+    "@app.route(\"/\")\n",
+    "def home():\n",
+    "    return jsonify(\"hello world\")\n",
+    "\n",
+    "@app.route(\"/api/transform\", methods=['POST'])\n",
+    "def begin():\n",
+    "    make_dir()\n",
+    "    print(\"data received\")\n",
+    "    cloth = request.files['cloth']\n",
+    "    model = request.files['model']\n",
+    "\n",
+    "    cloth = Image.open(cloth.stream)\n",
+    "    model = Image.open(model.stream)\n",
+    "\n",
+    "    cloth.save(\"/content/inputs/test/cloth/cloth.jpg\")\n",
+    "    model.save(\"/content/inputs/test/image/model.jpg\")\n",
+    "\n",
+    "    # run the try-on pipeline to compute the prediction\n",
+    "    os.system(\"python /content/clothes-virtual-try-on/run.py\")\n",
+    "\n",
+    "    # load the generated output image\n",
+    "    op = os.listdir(\"/content/output\")[0]\n",
+    "    op = Image.open(f\"/content/output/{op}\")\n",
+    "    buffer = io.BytesIO()\n",
+    "    op.save(buffer, 'png')\n",
+    "    buffer.seek(0)\n",
+    "    os.system(\"rm -rf /content/output/\")\n",
+    "    return send_file(buffer, mimetype='image/png')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "Pl52gbcqZ3GL",
+    "outputId": "72cbd1c5-afa8-4411-9b3e-75b192b65a07"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " * Serving Flask app '__main__'\n",
+      " * Debug mode: off\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:werkzeug:\u001b[31m\u001b[1mWARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.\u001b[0m\n",
+      " * Running on http://127.0.0.1:5000\n",
+      "INFO:werkzeug:\u001b[33mPress CTRL+C to quit\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " * Running on http://e793-34-123-73-186.ngrok-free.app\n",
+      " * Traffic stats available on http://127.0.0.1:4040\n",
+      "data received\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:werkzeug:127.0.0.1 - - [19/Dec/2023 14:36:09] \"POST /api/transform HTTP/1.1\" 200 -\n"
+     ]
+    }
+   ],
+   "source": [
+    "if __name__ == '__main__':\n",
+    "    app.run()"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [
+    "8gqt11Y_RYxU"
+   ],
+   "provenance": []
+  },
+  "gpuClass": "standard",
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
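
A minimal client sketch for the /api/transform route defined above. Everything here is illustrative: the ngrok hostname is the throwaway one printed by app.run() and changes on every launch, and cloth.jpg / model.jpg are assumed to exist next to the script.

    import requests

    # POST a garment image and a person image to the try-on endpoint.
    # Replace the host with the ngrok URL printed when the notebook starts the server.
    url = "http://e793-34-123-73-186.ngrok-free.app/api/transform"
    with open("cloth.jpg", "rb") as cloth_f, open("model.jpg", "rb") as model_f:
        resp = requests.post(url, files={"cloth": cloth_f, "model": model_f}, timeout=600)
    resp.raise_for_status()

    # The endpoint streams back the rendered try-on image as PNG bytes.
    with open("result.png", "wb") as out_f:
        out_f.write(resp.content)
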
test.py
ADDED
@@ -0,0 +1,155 @@
+import argparse
+import os
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torchgeometry as tgm
+
+from datasets import VITONDataset, VITONDataLoader
+from network import SegGenerator, GMM, ALIASGenerator
+from utils import gen_noise, load_checkpoint, save_images
+
+
+def get_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--name', type=str, required=True)
+
+    parser.add_argument('-b', '--batch_size', type=int, default=1)
+    parser.add_argument('-j', '--workers', type=int, default=1)
+    parser.add_argument('--load_height', type=int, default=1024)
+    parser.add_argument('--load_width', type=int, default=768)
+    parser.add_argument('--shuffle', action='store_true')
+
+    parser.add_argument('--dataset_dir', type=str, default='./datasets/')
+    parser.add_argument('--dataset_mode', type=str, default='test')
+    parser.add_argument('--dataset_list', type=str, default='test_pairs.txt')
+    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/')
+    parser.add_argument('--save_dir', type=str, default='./results/')
+
+    parser.add_argument('--display_freq', type=int, default=1)
+
+    parser.add_argument('--seg_checkpoint', type=str, default='seg_final.pth')
+    parser.add_argument('--gmm_checkpoint', type=str, default='gmm_final.pth')
+    parser.add_argument('--alias_checkpoint', type=str, default='alias_final.pth')
+
+    # common
+    parser.add_argument('--semantic_nc', type=int, default=13, help='# of human-parsing map classes')
+    parser.add_argument('--init_type', choices=['normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none'], default='xavier')
+    parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
+
+    # for GMM
+    parser.add_argument('--grid_size', type=int, default=5)
+
+    # for ALIASGenerator
+    parser.add_argument('--norm_G', type=str, default='spectralaliasinstance')
+    parser.add_argument('--ngf', type=int, default=64, help='# of generator filters in the first conv layer')
+    parser.add_argument('--num_upsampling_layers', choices=['normal', 'more', 'most'], default='most',
+                        help='If \'more\', add upsampling layer between the two middle resnet blocks. '
+                             'If \'most\', also add one more (upsampling + resnet) layer at the end of the generator.')
+
+    opt = parser.parse_args()
+    return opt
+
+
+def test(opt, seg, gmm, alias):
+    up = nn.Upsample(size=(opt.load_height, opt.load_width), mode='bilinear')
+    gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
+    gauss.cuda()
+
+    test_dataset = VITONDataset(opt)
+    test_loader = VITONDataLoader(opt, test_dataset)
+
+    with torch.no_grad():
+        for i, inputs in enumerate(test_loader.data_loader):
+            img_names = inputs['img_name']
+            c_names = inputs['c_name']['unpaired']
+
+            img_agnostic = inputs['img_agnostic'].cuda()
+            parse_agnostic = inputs['parse_agnostic'].cuda()
+            pose = inputs['pose'].cuda()
+            c = inputs['cloth']['unpaired'].cuda()
+            cm = inputs['cloth_mask']['unpaired'].cuda()
+
+            # Part 1. Segmentation generation
+            parse_agnostic_down = F.interpolate(parse_agnostic, size=(256, 192), mode='bilinear')
+            pose_down = F.interpolate(pose, size=(256, 192), mode='bilinear')
+            c_masked_down = F.interpolate(c * cm, size=(256, 192), mode='bilinear')
+            cm_down = F.interpolate(cm, size=(256, 192), mode='bilinear')
+            seg_input = torch.cat((cm_down, c_masked_down, parse_agnostic_down, pose_down, gen_noise(cm_down.size()).cuda()), dim=1)
+
+            parse_pred_down = seg(seg_input)
+            parse_pred = gauss(up(parse_pred_down))
+            parse_pred = parse_pred.argmax(dim=1)[:, None]
+
+            parse_old = torch.zeros(parse_pred.size(0), 13, opt.load_height, opt.load_width, dtype=torch.float).cuda()
+            parse_old.scatter_(1, parse_pred, 1.0)
+
+            labels = {
+                0: ['background', [0]],
+                1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
+                2: ['upper', [3]],
+                3: ['hair', [1]],
+                4: ['left_arm', [5]],
+                5: ['right_arm', [6]],
+                6: ['noise', [12]]
+            }
+            parse = torch.zeros(parse_pred.size(0), 7, opt.load_height, opt.load_width, dtype=torch.float).cuda()
+            for j in range(len(labels)):
+                for label in labels[j][1]:
+                    parse[:, j] += parse_old[:, label]
+
+            # Part 2. Clothes Deformation
+            agnostic_gmm = F.interpolate(img_agnostic, size=(256, 192), mode='nearest')
+            parse_cloth_gmm = F.interpolate(parse[:, 2:3], size=(256, 192), mode='nearest')
+            pose_gmm = F.interpolate(pose, size=(256, 192), mode='nearest')
+            c_gmm = F.interpolate(c, size=(256, 192), mode='nearest')
+            gmm_input = torch.cat((parse_cloth_gmm, pose_gmm, agnostic_gmm), dim=1)
+
+            _, warped_grid = gmm(gmm_input, c_gmm)
+            warped_c = F.grid_sample(c, warped_grid, padding_mode='border')
+            warped_cm = F.grid_sample(cm, warped_grid, padding_mode='border')
+
+            # Part 3. Try-on synthesis
+            misalign_mask = parse[:, 2:3] - warped_cm
+            misalign_mask[misalign_mask < 0.0] = 0.0
+            parse_div = torch.cat((parse, misalign_mask), dim=1)
+            parse_div[:, 2:3] -= misalign_mask
+
+            output = alias(torch.cat((img_agnostic, pose, warped_c), dim=1), parse, parse_div, misalign_mask)
+
+            unpaired_names = []
+            for img_name, c_name in zip(img_names, c_names):
+                unpaired_names.append('{}_{}'.format(img_name.split('_')[0], c_name))
+
+            save_images(output, unpaired_names, os.path.join(opt.save_dir, opt.name))
+
+            if (i + 1) % opt.display_freq == 0:
+                print("step: {}".format(i + 1))
+
+
+def main():
+    opt = get_opt()
+    print(opt)
+
+    if not os.path.exists(os.path.join(opt.save_dir, opt.name)):
+        os.makedirs(os.path.join(opt.save_dir, opt.name))
+
+    seg = SegGenerator(opt, input_nc=opt.semantic_nc + 8, output_nc=opt.semantic_nc)
+    gmm = GMM(opt, inputA_nc=7, inputB_nc=3)
+    opt.semantic_nc = 7
+    alias = ALIASGenerator(opt, input_nc=9)
+    opt.semantic_nc = 13
+
+    load_checkpoint(seg, os.path.join(opt.checkpoint_dir, opt.seg_checkpoint))
+    load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, opt.gmm_checkpoint))
+    load_checkpoint(alias, os.path.join(opt.checkpoint_dir, opt.alias_checkpoint))
+
+    seg.cuda().eval()
+    gmm.cuda().eval()
+    alias.cuda().eval()
+    test(opt, seg, gmm, alias)
+
+
+if __name__ == '__main__':
+    main()
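
In Part 1 of test() above, the labels dict regroups the 13-channel parse prediction into the 7 channel groups that the GMM warping and ALIAS synthesis stages consume. A standalone sketch of that regrouping with a tiny dummy tensor (the spatial size is arbitrary, chosen only for illustration):

    import torch

    # Same channel grouping as the labels dict in test.py:
    # e.g. source channels 2, 4, 7, 8, 9, 10, 11 all collapse into group 1 ('paste').
    groups = {0: [0], 1: [2, 4, 7, 8, 9, 10, 11], 2: [3], 3: [1], 4: [5], 5: [6], 6: [12]}

    parse_old = torch.zeros(1, 13, 8, 6)   # dummy one-hot 13-class parse map
    parse_old[:, 3, 2:5, 1:4] = 1.0        # mark a patch of 'upper clothes' (channel 3)

    parse = torch.zeros(1, 7, 8, 6)
    for g, src_channels in groups.items():
        for ch in src_channels:
            parse[:, g] += parse_old[:, ch]        # sum each source channel into its group

    assert parse[:, 2].sum() == parse_old[:, 3].sum()  # group 2 ('upper') now holds channel 3
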
utils.py
ADDED
@@ -0,0 +1,40 @@
+import os
+
+import cv2
+import numpy as np
+from PIL import Image
+import torch
+
+
+def gen_noise(shape):
+    noise = np.zeros(shape, dtype=np.uint8)
+    ### noise
+    noise = cv2.randn(noise, 0, 255)
+    noise = np.asarray(noise / 255, dtype=np.uint8)
+    noise = torch.tensor(noise, dtype=torch.float32)
+    return noise
+
+
+def save_images(img_tensors, img_names, save_dir):
+    for img_tensor, img_name in zip(img_tensors, img_names):
+        tensor = (img_tensor.clone()+1)*0.5 * 255
+        tensor = tensor.cpu().clamp(0,255)
+
+        try:
+            array = tensor.numpy().astype('uint8')
+        except:
+            array = tensor.detach().numpy().astype('uint8')
+
+        if array.shape[0] == 1:
+            array = array.squeeze(0)
+        elif array.shape[0] == 3:
+            array = array.swapaxes(0, 1).swapaxes(1, 2)
+
+        im = Image.fromarray(array)
+        im.save(os.path.join(save_dir, img_name), format='JPEG')
+
+
+def load_checkpoint(model, checkpoint_path):
+    if not os.path.exists(checkpoint_path):
+        raise ValueError("'{}' is not a valid checkpoint path".format(checkpoint_path))
+    model.load_state_dict(torch.load(checkpoint_path))
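
save_images above assumes generator outputs in the tanh range [-1, 1] and rescales them to 8-bit pixels via (x + 1) * 0.5 * 255; gen_noise, note, effectively yields a {0, 1}-valued map because of the intermediate uint8 cast. A round-trip sketch for save_images under that range assumption (the ./results/demo path and file name are hypothetical):

    import os
    import torch
    from utils import save_images  # the helper defined above

    os.makedirs('./results/demo', exist_ok=True)      # save_images expects the directory to exist
    fake_output = torch.rand(1, 3, 256, 192) * 2 - 1  # dummy batch in the tanh output range
    save_images(fake_output, ['demo_00.jpg'], './results/demo')  # writes ./results/demo/demo_00.jpg
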