diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..082fa1bd9c2578b65d56d1678fe768a39f025107
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,6 @@
+sd-civitai-browser-plus/aria2/lin/aria2 filter=lfs diff=lfs merge=lfs -text
+sd-civitai-browser-plus/aria2/win/aria2.exe filter=lfs diff=lfs merge=lfs -text
+sd-webui-controlnet/annotator/oneformer/oneformer/data/bpe_simple_vocab_16e6.txt.gz filter=lfs diff=lfs merge=lfs -text
+sd-webui-inpaint-anything/images/inpaint_anything_ui_image_1.png filter=lfs diff=lfs merge=lfs -text
+stable-diffusion-webui-aesthetic-gradients/ss.png filter=lfs diff=lfs merge=lfs -text
+stable-diffusion-webui-rembg/preview.png filter=lfs diff=lfs merge=lfs -text
diff --git a/adetailer/.github/ISSUE_TEMPLATE/bug_report.yaml b/adetailer/.github/ISSUE_TEMPLATE/bug_report.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..90a577252d7fb16397fd1b1da4117ca28343c622
--- /dev/null
+++ b/adetailer/.github/ISSUE_TEMPLATE/bug_report.yaml
@@ -0,0 +1,53 @@
+name: Bug report
+description: Create a report
+title: "[Bug]: "
+labels:
+ - bug
+
+body:
+ - type: textarea
+ attributes:
+ label: Describe the bug
+ description: A clear and concise description of what the bug is.
+ placeholder: |
+ Any language accepted
+ 아무 언어 사용가능
+ すべての言語に対応
+ 接受所有语言
+ Se aceptan todos los idiomas
+ Alle Sprachen werden akzeptiert
+ Toutes les langues sont acceptées
+ Принимаются все языки
+
+ - type: textarea
+ attributes:
+ label: Screenshots
+ description: Screenshots related to the issue.
+
+ - type: textarea
+ attributes:
+ label: Console logs, from start to end.
+ description: |
+ The full console log of your terminal.
+ placeholder: |
+ Python ...
+ Version: ...
+ Commit hash: ...
+ Installing requirements
+ ...
+
+ Launching Web UI with arguments: ...
+ [-] ADetailer initialized. version: ...
+ ...
+ ...
+
+ Traceback (most recent call last):
+ ...
+ ...
+ render: Shell
+ validations:
+ required: true
+
+ - type: textarea
+ attributes:
+ label: List of installed extensions
diff --git a/adetailer/.github/ISSUE_TEMPLATE/feature_request.yaml b/adetailer/.github/ISSUE_TEMPLATE/feature_request.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c496137ce397dd21082f674eb4772edc529ee74e
--- /dev/null
+++ b/adetailer/.github/ISSUE_TEMPLATE/feature_request.yaml
@@ -0,0 +1,24 @@
+name: Feature request
+description: Suggest an idea for this project
+title: "[Feature Request]: "
+
+body:
+ - type: textarea
+ attributes:
+ label: Is your feature request related to a problem? Please describe.
+ description: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+
+ - type: textarea
+ attributes:
+ label: Describe the solution you'd like
+ description: A clear and concise description of what you want to happen.
+
+ - type: textarea
+ attributes:
+ label: Describe alternatives you've considered
+ description: A clear and concise description of any alternative solutions or features you've considered.
+
+ - type: textarea
+ attributes:
+ label: Additional context
+ description: Add any other context or screenshots about the feature request here.
diff --git a/adetailer/.github/ISSUE_TEMPLATE/question.yaml b/adetailer/.github/ISSUE_TEMPLATE/question.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3c7945407890bc48dc5305fa979047a2bf2f58e8
--- /dev/null
+++ b/adetailer/.github/ISSUE_TEMPLATE/question.yaml
@@ -0,0 +1,10 @@
+name: Question
+description: Write a question
+labels:
+ - question
+
+body:
+ - type: textarea
+ attributes:
+ label: Question
+ description: Please do not write bug reports or feature requests here.
diff --git a/adetailer/.github/workflows/stale.yml b/adetailer/.github/workflows/stale.yml
new file mode 100644
index 0000000000000000000000000000000000000000..79ab8faf57a9e958a3cf784b5ada4994e8c528d2
--- /dev/null
+++ b/adetailer/.github/workflows/stale.yml
@@ -0,0 +1,13 @@
+name: 'Close stale issues and PRs'
+on:
+ schedule:
+ - cron: '30 1 * * *'
+
+jobs:
+ stale:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/stale@v8
+ with:
+ days-before-stale: 23
+ days-before-close: 3
diff --git a/adetailer/.gitignore b/adetailer/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..ce19e6c13e7687fe5bb70cb1138f8a7cb65162d9
--- /dev/null
+++ b/adetailer/.gitignore
@@ -0,0 +1,196 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode
+*.ipynb
diff --git a/adetailer/.pre-commit-config.yaml b/adetailer/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bb5e142adf9fef598bdb1a5d26189324b419b135
--- /dev/null
+++ b/adetailer/.pre-commit-config.yaml
@@ -0,0 +1,20 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: check-ast
+ - id: trailing-whitespace
+ args: [--markdown-linebreak-ext=md]
+ - id: end-of-file-fixer
+ - id: mixed-line-ending
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.1.14
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+
+ - repo: https://github.com/psf/black-pre-commit-mirror
+ rev: 23.12.1
+ hooks:
+ - id: black
diff --git a/adetailer/.vscode/extensions.json b/adetailer/.vscode/extensions.json
new file mode 100644
index 0000000000000000000000000000000000000000..f18738dfa515e878d914b3a872e7469a7eb0f310
--- /dev/null
+++ b/adetailer/.vscode/extensions.json
@@ -0,0 +1,8 @@
+{
+ "recommendations": [
+ "ms-python.black-formatter",
+ "kevinrose.vsc-python-indent",
+ "charliermarsh.ruff",
+ "shardulm94.trailing-spaces"
+ ]
+}
diff --git a/adetailer/.vscode/settings.json b/adetailer/.vscode/settings.json
new file mode 100644
index 0000000000000000000000000000000000000000..ceaf79adb5c0aa538a8514c423cdb62fc10de43f
--- /dev/null
+++ b/adetailer/.vscode/settings.json
@@ -0,0 +1,8 @@
+{
+ "explorer.fileNesting.enabled": true,
+ "explorer.fileNesting.patterns": {
+ "pyproject.toml": ".env, .gitignore, .pre-commit-config.yaml, Taskfile.yml",
+ "README.md": "LICENSE.md, CHANGELOG.md",
+ "install.py": "preload.py"
+ }
+}
diff --git a/adetailer/CHANGELOG.md b/adetailer/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..d1097786b04ed075efcdd8762c01a726f0b520e8
--- /dev/null
+++ b/adetailer/CHANGELOG.md
@@ -0,0 +1,377 @@
+# Changelog
+
+## 2024-01-23
+
+- v24.1.2
+- controlnet 모델에 `Passthrough` 옵션 추가. 입력으로 들어온 컨트롤넷 옵션을 그대로 사용
+- fastapi 엔드포인트 추가
+
+## 2024-01-10
+
+- v24.1.1
+- SDNext 호환 업데이트 (issue #466)
+ - 설정 값 state에 초기값 추가
+ - 위젯 값을 변경할 때마다 state도 변경되게 함 (기존에는 생성 버튼을 누를 때 적용되었음)
+- `inpaint_depth_hand` 컨트롤넷 모델이 depth 모델로 인식되게 함 (issue #463)
+
+## 2024-01-04
+
+- v24.1.0
+- `depth_hand_refiner` ControlNet 추가 (PR #460)
+
+## 2023-12-30
+
+- v23.12.0
+- 파일을 인자로 추가하는 몇몇 스크립트에 대해 deepcopy의 에러를 피하기 위해 script_args 복사 방법을 변경함
+- skip img2img 기능을 사용할 때 너비, 높이를 128로 고정하여 스킵 과정이 조금 더 나아짐
+- img2img inpainting 모드에서 adetailer 자동 비활성화
+- 처음 생성된 params.txt 파일을 항상 유지하도록 변경함
+
+## 2023-11-19
+
+- v23.11.1
+- 기본 스크립트 목록에 negpip 추가
+ - 기존에 설치한 사람에게 소급적용되지는 않음
+- skip img2img 옵션이 2스텝 이상일 때, 제대로 적용되지 않는 문제 수정
+- SD.Next에서 이미지가 np.ndarray로 입력되는 경우 수정
+- 컨트롤넷 경로를 sys.path에 추가하여 `--data-dir` 등을 지정한 경우에도 임포트 에러가 일어나지 않게 함.
+
+## 2023-10-30
+
+- v23.11.0
+- 이미지의 인덱스 계산방법 변경
+ - webui 1.1.0 미만에서 adetailer 실행 불가능하게 함
+- 컨트롤넷 preprocessor 선택지 늘림
+- 추가 yolo 모델 디렉터리를 설정할 수 있는 옵션 추가
+- infotext에 `/`가 있는 항목이 exif에서 복원되지 않는 문제 수정
+ - 이전 버전에 생성된 이미지는 여전히 복원안됨
+- 같은 탭에서 항상 같은 시드를 적용하게 하는 옵션 추가
+- 컨트롤넷 1.1.411 (f2aafcf2beb99a03cbdf7db73852228ccd6bd1d6) 버전을 사용중일 경우,
+ webui 버전 1.6.0 미만에서 사용할 수 없다는 메시지 출력
+
+## 2023-10-15
+
+- v23.10.1
+- xyz grid에 prompt S/R 추가
+- img2img에서 steps가 1일때 에러가 발생하는 샘플러의 처리를 위해 샘플러 이름도 변경하게 수정
+
+## 2023-10-07
+
+- v23.10.0
+- 허깅페이스 모델을 다운로드 실패했을 때, 계속 다운로드를 시도하지 않음
+- img2img에서 img2img단계를 건너뛰는 기능 추가
+- live preview에서 감지 단계를 보여줌 (PR #352)
+
+## 2023-09-20
+
+- v23.9.3
+- ultralytics 버전 8.0.181로 업데이트 (https://github.com/ultralytics/ultralytics/pull/4891)
+- mediapipe와 ultralytics의 lazy import
+
+## 2023-09-10
+
+- v23.9.2
+- (실험적) VAE 선택 기능
+
+## 2023-09-01
+
+- v23.9.1
+- webui 1.6.0에 추가된 인자를 사용해서 생긴 하위 호환 문제 수정
+
+## 2023-08-31
+
+- v23.9.0
+- (실험적) 체크포인트 선택기능
+ - 버그가 있어 리프레시 버튼은 구현에서 빠짐
+- 1.6.0 업데이트에 따라 img2img에서 사용불가능한 샘플러를 선택했을 때 더이상 Euler로 변경하지 않음
+- 유효하지 않은 인자가 전달되었을 때, 에러를 일으키지 않고 대신 adetailer를 비활성화함
+
+
+## 2023-08-25
+
+- v23.8.1
+- xyz grid에서 model을 `None`으로 설정한 이후에 adetailer가 비활성화 되는 문제 수정
+- skip을 눌렀을 때 진행을 멈춤
+- `--medvram-sdxl`을 설정했을 때에도 cpu를 사용하게 함
+
+## 2023-08-14
+
+- v23.8.0
+- `[PROMPT]` 키워드 추가. `ad_prompt` 또는 `ad_negative_prompt`에 사용하면 입력 프롬프트로 대체됨 (PR #243)
+- Only top k largest 옵션 추가 (PR #264)
+- ultralytics 버전 업데이트
+
+
+## 2023-07-31
+
+- v23.7.11
+- separate clip skip 옵션 추가
+- install requirements 정리 (ultralytics 새 버전, mediapipe~=3.20)
+
+## 2023-07-28
+
+- v23.7.10
+- ultralytics, mediapipe import문 정리
+- traceback에서 컬러를 없앰 (api 때문), 라이브러리 버전도 보여주게 설정.
+- huggingface_hub, pydantic을 install.py에서 없앰
+- 안쓰는 컨트롤넷 관련 코드 삭제
+
+
+## 2023-07-23
+
+- v23.7.9
+- `ultralytics.utils` ModuleNotFoundError 해결 (https://github.com/ultralytics/ultralytics/issues/3856)
+- `pydantic` 2.0 이상 버전 설치안되도록 함
+- `controlnet_dir` cmd args 문제 수정 (PR #107)
+
+## 2023-07-20
+
+- v23.7.8
+- `paste_field_names` 추가했던 것을 되돌림
+
+## 2023-07-19
+
+- v23.7.7
+- 인페인팅 단계에서 별도의 샘플러를 선택할 수 있게 옵션을 추가함 (xyz그리드에도 추가)
+- webui 1.0.0-pre 이하 버전에서 batch index 문제 수정
+- 스크립트에 `paste_field_names`을 추가함. 사용되는지는 모르겠음
+
+## 2023-07-16
+
+- v23.7.6
+- `ultralytics 8.0.135`에 추가된 cpuinfo 기능을 위해 `py-cpuinfo`를 미리 설치하게 함. (미리 설치하지 않으면 cpu나 mps를 사용할 때 재시작해야 함)
+- init_image가 RGB 모드가 아닐 때 RGB로 변경.
+
+## 2023-07-07
+
+- v23.7.4
+- batch count > 1일때 프롬프트의 인덱스 문제 수정
+
+- v23.7.5
+- i2i의 `cached_uc`와 `cached_c`가 p의 `cached_uc`, `cached_c`와 다른 인스턴스가 되도록 수정
+
+## 2023-07-05
+
+- v23.7.3
+- 버그 수정
+ - `object()`가 json 직렬화 안되는 문제
+ - `process`를 호출함에 따라 배치 카운트가 2이상일 때, all_prompts가 고정되는 문제
+ - `ad-before`와 `ad-preview` 이미지 파일명이 실제 파일명과 다른 문제
+ - pydantic 2.0 호환성 문제
+
+## 2023-07-04
+
+- v23.7.2
+- `mediapipe_face_mesh_eyes_only` 모델 추가: `mediapipe_face_mesh`로 감지한 뒤 눈만 사용함.
+- 매 배치 시작 전에 `scripts.postprocess`를, 후에 `scripts.process`를 호출함.
+ - 컨트롤넷을 사용하면 소요 시간이 조금 늘어나지만 몇몇 문제 해결에 도움이 됨.
+- `lora_block_weight`를 스크립트 화이트리스트에 추가함.
+ - 한번이라도 ADetailer를 사용한 사람은 수동으로 추가해야함.
+
+## 2023-07-03
+
+- v23.7.1
+- `process_images`를 진행한 뒤 `StableDiffusionProcessing` 오브젝트의 close를 호출함
+- api 호출로 사용했는지 확인하는 속성 추가
+- `NansException`이 발생했을 때 중지하지 않고 남은 과정 계속 진행함
+
+## 2023-07-02
+
+- v23.7.0
+- `NansException`이 발생하면 로그에 표시하고 원본 이미지를 반환하게 설정
+- `rich`를 사용한 에러 트레이싱
+ - install.py에 `rich` 추가
+- 생성 중에 컴포넌트의 값을 변경하면 args의 값도 함께 변경되는 문제 수정 (issue #180)
+- 터미널 로그로 ad_prompt와 ad_negative_prompt에 적용된 실제 프롬프트 확인할 수 있음 (입력과 다를 경우에만)
+
+## 2023-06-28
+
+- v23.6.4
+- 최대 모델 수 5 -> 10개
+- ad_prompt와 ad_negative_prompt에 빈칸으로 놔두면 입력 프롬프트가 사용된다는 문구 추가
+- huggingface 모델 다운로드 실패시 로깅
+- 1st 모델이 `None`일 경우 나머지 입력을 무시하던 문제 수정
+- `--use-cpu` 에 `adetailer` 입력 시 cpu로 yolo모델을 사용함
+
+## 2023-06-20
+
+- v23.6.3
+- 컨트롤넷 inpaint 모델에 대해, 3가지 모듈을 사용할 수 있도록 함
+- Noise Multiplier 옵션 추가 (PR #149)
+- pydantic 최소 버전 1.10.8로 설정 (Issue #146)
+
+## 2023-06-05
+
+- v23.6.2
+- xyz_grid에서 ADetailer를 사용할 수 있게함.
+ - 8가지 옵션만 1st 탭에 적용되도록 함.
+
+## 2023-06-01
+
+- v23.6.1
+- `inpaint, scribble, lineart, openpose, tile` 5가지 컨트롤넷 모델 지원 (PR #107)
+- controlnet guidance start, end 인자 추가 (PR #107)
+- `modules.extensions`를 사용하여 컨트롤넷 확장을 불러오고 경로를 알아내도록 변경
+- ui에서 컨트롤넷을 별도 함수로 분리
+
+## 2023-05-30
+
+- v23.6.0
+- 스크립트의 이름을 `After Detailer`에서 `ADetailer`로 변경
+ - API 사용자는 변경 필요함
+- 몇몇 설정 변경
+ - `ad_conf` → `ad_confidence`. 0~100 사이의 int → 0.0~1.0 사이의 float
+ - `ad_inpaint_full_res` → `ad_inpaint_only_masked`
+ - `ad_inpaint_full_res_padding` → `ad_inpaint_only_masked_padding`
+- mediapipe face mesh 모델 추가
+ - mediapipe 최소 버전 `0.10.0`
+
+- rich traceback 제거함
+- huggingface 다운로드 실패할 때 에러가 나지 않게 하고 해당 모델을 제거함
+
+## 2023-05-26
+
+- v23.5.19
+- 1번째 탭에도 `None` 옵션을 추가함
+- api로 ad controlnet model에 inpaint가 아닌 다른 컨트롤넷 모델을 사용하지 못하도록 막음
+- adetailer 진행중에 total tqdm 진행바 업데이트를 멈춤
+- state.interrupted 상태에서 adetailer 과정을 중지함
+- 컨트롤넷 process를 각 batch가 끝난 순간에만 호출하도록 변경
+
+### 2023-05-25
+
+- v23.5.18
+- 컨트롤넷 관련 수정
+ - unit의 `input_mode`를 `SIMPLE`로 모두 변경
+ - 컨트롤넷 유넷 훅과 하이잭 함수들을 adetailer를 실행할 때에만 되돌리는 기능 추가
+ - adetailer 처리가 끝난 뒤 컨트롤넷 스크립트의 process를 다시 진행함. (batch count 2 이상일때의 문제 해결)
+- 기본 활성 스크립트 목록에서 컨트롤넷을 뺌
+
+### 2023-05-22
+
+- v23.5.17
+- 컨트롤넷 확장이 있으면 컨트롤넷 스크립트를 활성화함. (컨트롤넷 관련 문제 해결)
+- 모든 컴포넌트에 elem_id 설정
+- ui에 버전을 표시함
+
+
+### 2023-05-19
+
+- v23.5.16
+- 추가한 옵션
+ - Mask min/max ratio
+ - Mask merge mode
+ - Restore faces after ADetailer
+- 옵션들을 Accordion으로 묶음
+
+### 2023-05-18
+
+- v23.5.15
+- 필요한 것만 임포트하도록 변경 (vae 로딩 오류 없어짐. 로딩 속도 빨라짐)
+
+### 2023-05-17
+
+- v23.5.14
+- `[SKIP]`으로 ad prompt 일부를 건너뛰는 기능 추가
+- bbox 정렬 옵션 추가
+- sd_webui 타입힌트를 만들어냄
+- enable checker와 관련된 api 오류 수정?
+
+### 2023-05-15
+
+- v23.5.13
+- `[SEP]`으로 ad prompt를 분리하여 적용하는 기능 추가
+- enable checker를 다시 pydantic으로 변경함
+- ui 관련 함수를 adetailer.ui 폴더로 분리함
+- controlnet을 사용할 때 모든 controlnet unit 비활성화
+- adetailer 폴더가 없으면 만들게 함
+
+### 2023-05-13
+
+- v23.5.12
+- `ad_enable`을 제외한 입력이 dict타입으로 들어오도록 변경
+ - web api로 사용할 때에 특히 사용하기 쉬움
+ - web api breaking change
+- `mask_preprocess` 인자를 넣지 않았던 오류 수정 (PR #47)
+- huggingface에서 모델을 다운로드하지 않는 옵션 추가 `--ad-no-huggingface`
+
+### 2023-05-12
+
+- v23.5.11
+- `ultralytics` 알람 제거
+- 필요없는 exif 인자 더 제거함
+- `use separate steps` 옵션 추가
+- ui 배치를 조정함
+
+### 2023-05-09
+
+- v23.5.10
+- 선택한 스크립트만 ADetailer에 적용하는 옵션 추가, 기본값 `True`. 설정 탭에서 지정가능.
+ - 기본값: `dynamic_prompting,dynamic_thresholding,wildcards,wildcard_recursive`
+- `person_yolov8s-seg.pt` 모델 추가
+- `ultralytics`의 최소 버전을 `8.0.97`로 설정 (C:\\ 문제 해결된 버전)
+
+### 2023-05-08
+
+- v23.5.9
+- 2가지 이상의 모델을 사용할 수 있음. 기본값: 2, 최대: 5
+- segment 모델을 사용할 수 있게 함. `person_yolov8n-seg.pt` 추가
+
+### 2023-05-07
+
+- v23.5.8
+- 프롬프트와 네거티브 프롬프트에 방향키 지원 (PR #24)
+- `mask_preprocess`를 추가함. 이전 버전과 시드값이 달라질 가능성 있음!
+- 이미지 처리가 일어났을 때에만 before이미지를 저장함
+- 설정창의 레이블을 ADetailer 대신 더 적절하게 수정함
+
+### 2023-05-06
+
+- v23.5.7
+- `ad_use_cfg_scale` 옵션 추가. cfg 스케일을 따로 사용할지 말지 결정함.
+- `ad_enable` 기본값을 `True`에서 `False`로 변경
+- `ad_model`의 기본값을 `None`에서 첫번째 모델로 변경
+- 최소 2개의 입력(ad_enable, ad_model)만 들어오면 작동하게 변경.
+
+- v23.5.7.post0
+- `init_controlnet_ext`을 controlnet_exists == True일때에만 실행
+- webui를 C드라이브 바로 밑에 설치한 사람들에게 `ultralytics` 경고 표시
+
+### 2023-05-05 (어린이날)
+
+- v23.5.5
+- `Save images before ADetailer` 옵션 추가
+- 입력으로 들어온 인자와 ALL_ARGS의 길이가 다르면 에러 메시지
+- README.md에 설치방법 추가
+
+- v23.5.6
+- get_args에서 IndexError가 발생하면 자세한 에러 메시지를 볼 수 있음
+- AdetailerArgs에 extra_params 내장
+- scripts_args를 딥카피함
+- postprocess_image를 약간 분리함
+
+- v23.5.6.post0
+- `init_controlnet_ext`에서 에러 메시지를 자세히 볼 수 있음
+
+### 2023-05-04
+
+- v23.5.4
+- use pydantic for arguments validation
+- revert: ad_model to `None` as default
+- revert: `__future__` imports
+- lazily import yolo and mediapipe
+
+### 2023-05-03
+
+- v23.5.3.post0
+- remove `__future__` imports
+- change to copy scripts and scripts args
+
+- v23.5.3.post1
+- change default ad_model from `None`
+
+### 2023-05-02
+
+- v23.5.3
+- Remove `None` from model list and add `Enable ADetailer` checkbox.
+- install.py `skip_install` fix.
diff --git a/adetailer/LICENSE.md b/adetailer/LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..15bc112be2418653138c879e8f15c7b001229324
--- /dev/null
+++ b/adetailer/LICENSE.md
@@ -0,0 +1,662 @@
+
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/adetailer/README.md b/adetailer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3a00116f781be7eb2802529745423489635c9076
--- /dev/null
+++ b/adetailer/README.md
@@ -0,0 +1,97 @@
+# ADetailer
+
+ADetailer is an extension for stable diffusion webui, similar to Detection Detailer, except it uses ultralytics instead of mmdet.
+
+## Install
+
+(from Mikubill/sd-webui-controlnet)
+
+1. Open "Extensions" tab.
+2. Open "Install from URL" tab in the tab.
+3. Enter `https://github.com/Bing-su/adetailer.git` to "URL for extension's git repository".
+4. Press "Install" button.
+5. Wait 5 seconds, and you will see the message "Installed into stable-diffusion-webui\extensions\adetailer. Use Installed tab to restart".
+6. Go to "Installed" tab, click "Check for updates", and then click "Apply and restart UI". (The next time you can also use this method to update extensions.)
+7. Completely restart A1111 webui including your terminal. (If you do not know what a "terminal" is, you can reboot your computer: turn your computer off and turn it on again.)
+
+You can now install it directly from the Extensions tab.
+
+![image](https://i.imgur.com/g6GdRBT.png)
+
+You **DON'T** need to download any model from huggingface.
+
+## Options
+
+| Model, Prompts | | |
+| --------------------------------- | --------------------------------------------------------------------------------- | ------------------------------------------------- |
+| ADetailer model | Determine what to detect. | `None` = disable |
+| ADetailer prompt, negative prompt | Prompts and negative prompts to apply | If left blank, it will use the same as the input. |
+| Skip img2img | Skip img2img. In practice, this works by changing the step count of img2img to 1. | img2img only |
+
+| Detection | | |
+| ------------------------------------ | -------------------------------------------------------------------------------------------- | ------------ |
+| Detection model confidence threshold | Only objects with a detection model confidence above this threshold are used for inpainting. | |
+| Mask min/max ratio | Only use masks whose area is between those ratios for the area of the entire image. | |
+| Mask only the top k largest | Only use the k objects with the largest area of the bbox. | 0 to disable |
+
+If you want to exclude objects in the background, try setting the min ratio to around `0.01`.
+
+| Mask Preprocessing | | |
+| ------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------- |
+| Mask x, y offset | Moves the mask horizontally and vertically by the specified x and y offsets. | |
+| Mask erosion (-) / dilation (+) | Enlarge or reduce the detected mask. | [opencv example](https://docs.opencv.org/4.7.0/db/df6/tutorial_erosion_dilatation.html) |
+| Mask merge mode | `None`: Inpaint each mask<br/>`Merge`: Merge all masks and inpaint<br/>`Merge and Invert`: Merge all masks and invert, then inpaint | |
+
+Applied in this order: x, y offset → erosion/dilation → merge/invert.
+
+#### Inpainting
+
+Each option corresponds to the option of the same name on the inpaint tab. Please refer to the inpaint tab for details on how to use each option.
+
+## ControlNet Inpainting
+
+You can use ControlNet inpainting if you have the ControlNet extension and ControlNet models installed.
+
+Supports the `inpaint, scribble, lineart, openpose, tile` ControlNet models. Once you choose a model, the preprocessor is set automatically. It works separately from the model set by the ControlNet extension.
+
+## Advanced Options
+
+API request example: [wiki/API](https://github.com/Bing-su/adetailer/wiki/API)
+
+`ui-config.json` entries: [wiki/ui-config.json](https://github.com/Bing-su/adetailer/wiki/ui-config.json)
+
+`[SEP], [SKIP]` tokens: [wiki/Advanced](https://github.com/Bing-su/adetailer/wiki/Advanced)
+
+## Media
+
+- 🎥 [どこよりも詳しいAfter Detailer (adetailer)の使い方① 【Stable Diffusion】](https://youtu.be/sF3POwPUWCE)
+- 🎥 [どこよりも詳しいAfter Detailer (adetailer)の使い方② 【Stable Diffusion】](https://youtu.be/urNISRdbIEg)
+
+## Model
+
+| Model | Target | mAP 50 | mAP 50-95 |
+| --------------------- | --------------------- | ----------------------------- | ----------------------------- |
+| face_yolov8n.pt | 2D / realistic face | 0.660 | 0.366 |
+| face_yolov8s.pt | 2D / realistic face | 0.713 | 0.404 |
+| hand_yolov8n.pt | 2D / realistic hand | 0.767 | 0.505 |
+| person_yolov8n-seg.pt | 2D / realistic person | 0.782 (bbox)<br/>0.761 (mask) | 0.555 (bbox)<br/>0.460 (mask) |
+| person_yolov8s-seg.pt | 2D / realistic person | 0.824 (bbox)<br/>0.809 (mask) | 0.605 (bbox)<br/>0.508 (mask) |
+| mediapipe_face_full | realistic face | - | - |
+| mediapipe_face_short | realistic face | - | - |
+| mediapipe_face_mesh | realistic face | - | - |
+
+The yolo models can be found on huggingface [Bingsu/adetailer](https://huggingface.co/Bingsu/adetailer).
+
+### Additional Model
+
+Put your [ultralytics](https://github.com/ultralytics/ultralytics) yolo model in `webui/models/adetailer`. The model name should end with `.pt` or `.pth`.
+
+It must be a bbox detection or segmentation model and use all labels.
+
+## How it works
+
+ADetailer works in three simple steps.
+
+1. Create an image.
+2. Detect objects with a detection model and create a mask image.
+3. Inpaint using the image from 1 and the mask from 2.
diff --git a/adetailer/Taskfile.yml b/adetailer/Taskfile.yml
new file mode 100644
index 0000000000000000000000000000000000000000..70e0d357d11f0b8b908fb31f62ce0c104042d2c2
--- /dev/null
+++ b/adetailer/Taskfile.yml
@@ -0,0 +1,27 @@
+# https://taskfile.dev
+
+version: "3"
+
+dotenv:  # environment variables are loaded from the .env file
+ - .env
+
+tasks:
+ default:  # prints the PYTHON and WEBUI environment variables
+ cmds:
+ - echo "$PYTHON"
+ - echo "$WEBUI"
+ silent: true
+
+ launch:  # runs launch.py from the WEBUI directory with xformers and the API enabled
+ dir: "{{.WEBUI}}"
+ cmds:
+ - "{{.PYTHON}} launch.py --xformers --api"
+ silent: true
+
+ lint:  # runs every pre-commit hook against all files (-a)
+ cmds:
+ - pre-commit run -a
+
+ update:  # upgrades the listed packages with pip under the PYTHON interpreter
+ cmds:
+ - "{{.PYTHON}} -m pip install -U ultralytics mediapipe ruff pre-commit black"
diff --git a/adetailer/adetailer/__init__.py b/adetailer/adetailer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce389590f91d0e0924941dad30809539077ed6d1
--- /dev/null
+++ b/adetailer/adetailer/__init__.py
@@ -0,0 +1,18 @@
+from .__version__ import __version__
+from .args import ALL_ARGS, ADetailerArgs
+from .common import PredictOutput, get_models
+from .mediapipe import mediapipe_predict
+from .ultralytics import ultralytics_predict
+
+AFTER_DETAILER = "ADetailer"  # extension name constant, exported through __all__ below
+
+__all__ = [  # public names of the package; each one is imported or defined above
+ "__version__",
+ "ADetailerArgs",
+ "AFTER_DETAILER",
+ "ALL_ARGS",
+ "PredictOutput",
+ "get_models",
+ "mediapipe_predict",
+ "ultralytics_predict",
+]
diff --git a/adetailer/adetailer/__version__.py b/adetailer/adetailer/__version__.py
new file mode 100644
index 0000000000000000000000000000000000000000..869aeadd60f0e4efad6bb76e2915d1ad104cf222
--- /dev/null
+++ b/adetailer/adetailer/__version__.py
@@ -0,0 +1 @@
+__version__ = "24.1.2"  # package version string; re-exported by adetailer/__init__.py
diff --git a/adetailer/adetailer/args.py b/adetailer/adetailer/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..130bdeafa44d09e501fa1df87ccd7b3da8417e20
--- /dev/null
+++ b/adetailer/adetailer/args.py
@@ -0,0 +1,236 @@
+from __future__ import annotations
+
+from collections import UserList
+from dataclasses import dataclass
+from functools import cached_property, partial
+from typing import Any, Literal, NamedTuple, Optional
+
+import pydantic
+from pydantic import (
+ BaseModel,
+ Extra,
+ NonNegativeFloat,
+ NonNegativeInt,
+ PositiveInt,
+ confloat,
+ conint,
+ constr,
+ validator,
+)
+
+
+@dataclass
+class SkipImg2ImgOrig:
+ steps: int
+ sampler_name: str
+ width: int
+ height: int
+
+
+class Arg(NamedTuple):
+ """Pairs an ``ADetailerArgs`` attribute name with its display label."""
+
+ attr: str
+ name: str
+
+
+class ArgsList(UserList):
+ """
+ A list of ``(attr, name)`` pairs with cached per-column views.
+
+ Both properties are ``cached_property``, so they are computed on first
+ access and are not refreshed if the list is modified afterwards.
+ """
+
+ @cached_property
+ def attrs(self) -> tuple[str, ...]:
+ """First element of every pair, in list order."""
+ return tuple(attr for attr, _ in self)
+
+ @cached_property
+ def names(self) -> tuple[str, ...]:
+ """Second element of every pair, in list order."""
+ return tuple(name for _, name in self)
+
+
+class ADetailerArgs(BaseModel, extra=Extra.forbid):
+ """
+ Validated settings for a single ADetailer model slot.
+
+ Unknown fields are rejected (``extra=Extra.forbid``). Numeric fields carry
+ their allowed range in the field type (``confloat`` / ``conint`` / ...).
+ """
+
+ ad_model: str = "None"
+ ad_prompt: str = ""
+ ad_negative_prompt: str = ""
+ ad_confidence: confloat(ge=0.0, le=1.0) = 0.3
+ ad_mask_k_largest: NonNegativeInt = 0
+ ad_mask_min_ratio: confloat(ge=0.0, le=1.0) = 0.0
+ ad_mask_max_ratio: confloat(ge=0.0, le=1.0) = 1.0
+ ad_dilate_erode: int = 4
+ ad_x_offset: int = 0
+ ad_y_offset: int = 0
+ ad_mask_merge_invert: Literal["None", "Merge", "Merge and Invert"] = "None"
+ ad_mask_blur: NonNegativeInt = 4
+ ad_denoising_strength: confloat(ge=0.0, le=1.0) = 0.4
+ ad_inpaint_only_masked: bool = True
+ ad_inpaint_only_masked_padding: NonNegativeInt = 32
+ ad_use_inpaint_width_height: bool = False
+ ad_inpaint_width: PositiveInt = 512
+ ad_inpaint_height: PositiveInt = 512
+ ad_use_steps: bool = False
+ ad_steps: PositiveInt = 28
+ ad_use_cfg_scale: bool = False
+ ad_cfg_scale: NonNegativeFloat = 7.0
+ ad_use_checkpoint: bool = False
+ ad_checkpoint: Optional[str] = None
+ ad_use_vae: bool = False
+ ad_vae: Optional[str] = None
+ ad_use_sampler: bool = False
+ ad_sampler: str = "DPM++ 2M Karras"
+ ad_use_noise_multiplier: bool = False
+ ad_noise_multiplier: confloat(ge=0.5, le=1.5) = 1.0
+ ad_use_clip_skip: bool = False
+ ad_clip_skip: conint(ge=1, le=12) = 1
+ ad_restore_face: bool = False
+ ad_controlnet_model: str = "None"
+ ad_controlnet_module: str = "None"
+ ad_controlnet_weight: confloat(ge=0.0, le=1.0) = 1.0
+ ad_controlnet_guidance_start: confloat(ge=0.0, le=1.0) = 0.0
+ ad_controlnet_guidance_end: confloat(ge=0.0, le=1.0) = 1.0
+ is_api: bool = True
+
+ @validator("is_api", pre=True)
+ def is_api_validator(cls, v: Any): # noqa: N805
+ """
+ Derive ``is_api`` from the *type* of the incoming value.
+
+ Any tuple yields ``False``; everything else yields ``True``. A tuple can
+ be JSON-serialized but JSON deserialization never produces one, so a
+ tuple cannot have arrived through a JSON request body.
+ """
+ return type(v) is not tuple
+
+ @staticmethod
+ def ppop(
+ p: dict[str, Any],
+ key: str,
+ pops: list[str] | None = None,
+ cond: Any = None,
+ ) -> None:
+ """
+ Conditionally remove entries from ``p`` in place.
+
+ Parameters
+ ----------
+ p: dict[str, Any]
+ The dict to modify
+ key: str
+ The key whose value decides whether anything is removed.
+ Nothing happens when ``key`` is not in ``p``.
+ pops: list[str] | None
+ Keys to remove when the condition holds; defaults to ``[key]``.
+ Keys that are missing from ``p`` are ignored.
+ cond: Any
+ If None, the condition is "``p[key]`` is falsy";
+ otherwise it is "``p[key] == cond``".
+ """
+ if pops is None:
+ pops = [key]
+ if key not in p:
+ return
+ value = p[key]
+ cond = (not bool(value)) if cond is None else value == cond
+
+ if cond:
+ for k in pops:
+ p.pop(k, None)
+
+ def extra_params(self, suffix: str = "") -> dict[str, Any]:
+ """
+ Build a ``{label: value}`` dict of the settings worth recording.
+
+ Returns an empty dict when ``ad_model`` is ``"None"``. Otherwise it
+ starts from every ``ALL_ARGS`` entry and drops the ones that are empty
+ or equal to the value passed as ``cond`` below. A falsy "use ..." toggle
+ drops both the toggle and the value(s) it governs.
+
+ Parameters
+ ----------
+ suffix: str
+ Appended to every key of the result when non-empty
+ """
+ if self.ad_model == "None":
+ return {}
+
+ p = {name: getattr(self, attr) for attr, name in ALL_ARGS}
+ ppop = partial(self.ppop, p)
+
+ ppop("ADetailer prompt")
+ ppop("ADetailer negative prompt")
+ ppop("ADetailer mask only top k largest", cond=0)
+ ppop("ADetailer mask min ratio", cond=0.0)
+ ppop("ADetailer mask max ratio", cond=1.0)
+ ppop("ADetailer x offset", cond=0)
+ ppop("ADetailer y offset", cond=0)
+ ppop("ADetailer mask merge invert", cond="None")
+ # Only the padding is dropped here; the "inpaint only masked" flag itself stays.
+ ppop("ADetailer inpaint only masked", ["ADetailer inpaint padding"])
+ ppop(
+ "ADetailer use inpaint width height",
+ [
+ "ADetailer use inpaint width height",
+ "ADetailer inpaint width",
+ "ADetailer inpaint height",
+ ],
+ )
+ ppop(
+ "ADetailer use separate steps",
+ ["ADetailer use separate steps", "ADetailer steps"],
+ )
+ ppop(
+ "ADetailer use separate CFG scale",
+ ["ADetailer use separate CFG scale", "ADetailer CFG scale"],
+ )
+ ppop(
+ "ADetailer use separate checkpoint",
+ ["ADetailer use separate checkpoint", "ADetailer checkpoint"],
+ )
+ ppop(
+ "ADetailer use separate VAE",
+ ["ADetailer use separate VAE", "ADetailer VAE"],
+ )
+ ppop(
+ "ADetailer use separate sampler",
+ ["ADetailer use separate sampler", "ADetailer sampler"],
+ )
+ ppop(
+ "ADetailer use separate noise multiplier",
+ ["ADetailer use separate noise multiplier", "ADetailer noise multiplier"],
+ )
+
+ ppop(
+ "ADetailer use separate CLIP skip",
+ ["ADetailer use separate CLIP skip", "ADetailer CLIP skip"],
+ )
+
+ ppop("ADetailer restore face")
+ # No ControlNet model selected: drop every ControlNet entry at once.
+ ppop(
+ "ADetailer ControlNet model",
+ [
+ "ADetailer ControlNet model",
+ "ADetailer ControlNet module",
+ "ADetailer ControlNet weight",
+ "ADetailer ControlNet guidance start",
+ "ADetailer ControlNet guidance end",
+ ],
+ cond="None",
+ )
+ # A model is selected: still drop the individual entries that match the given value.
+ ppop("ADetailer ControlNet module", cond="None")
+ ppop("ADetailer ControlNet weight", cond=1.0)
+ ppop("ADetailer ControlNet guidance start", cond=0.0)
+ ppop("ADetailer ControlNet guidance end", cond=1.0)
+
+ if suffix:
+ p = {k + suffix: v for k, v in p.items()}
+
+ return p
+
+
+_all_args = [
+ ("ad_model", "ADetailer model"),
+ ("ad_prompt", "ADetailer prompt"),
+ ("ad_negative_prompt", "ADetailer negative prompt"),
+ ("ad_confidence", "ADetailer confidence"),
+ ("ad_mask_k_largest", "ADetailer mask only top k largest"),
+ ("ad_mask_min_ratio", "ADetailer mask min ratio"),
+ ("ad_mask_max_ratio", "ADetailer mask max ratio"),
+ ("ad_x_offset", "ADetailer x offset"),
+ ("ad_y_offset", "ADetailer y offset"),
+ ("ad_dilate_erode", "ADetailer dilate erode"),
+ ("ad_mask_merge_invert", "ADetailer mask merge invert"),
+ ("ad_mask_blur", "ADetailer mask blur"),
+ ("ad_denoising_strength", "ADetailer denoising strength"),
+ ("ad_inpaint_only_masked", "ADetailer inpaint only masked"),
+ ("ad_inpaint_only_masked_padding", "ADetailer inpaint padding"),
+ ("ad_use_inpaint_width_height", "ADetailer use inpaint width height"),
+ ("ad_inpaint_width", "ADetailer inpaint width"),
+ ("ad_inpaint_height", "ADetailer inpaint height"),
+ ("ad_use_steps", "ADetailer use separate steps"),
+ ("ad_steps", "ADetailer steps"),
+ ("ad_use_cfg_scale", "ADetailer use separate CFG scale"),
+ ("ad_cfg_scale", "ADetailer CFG scale"),
+ ("ad_use_checkpoint", "ADetailer use separate checkpoint"),
+ ("ad_checkpoint", "ADetailer checkpoint"),
+ ("ad_use_vae", "ADetailer use separate VAE"),
+ ("ad_vae", "ADetailer VAE"),
+ ("ad_use_sampler", "ADetailer use separate sampler"),
+ ("ad_sampler", "ADetailer sampler"),
+ ("ad_use_noise_multiplier", "ADetailer use separate noise multiplier"),
+ ("ad_noise_multiplier", "ADetailer noise multiplier"),
+ ("ad_use_clip_skip", "ADetailer use separate CLIP skip"),
+ ("ad_clip_skip", "ADetailer CLIP skip"),
+ ("ad_restore_face", "ADetailer restore face"),
+ ("ad_controlnet_model", "ADetailer ControlNet model"),
+ ("ad_controlnet_module", "ADetailer ControlNet module"),
+ ("ad_controlnet_weight", "ADetailer ControlNet weight"),
+ ("ad_controlnet_guidance_start", "ADetailer ControlNet guidance start"),
+ ("ad_controlnet_guidance_end", "ADetailer ControlNet guidance end"),
+]
+
+_args = [Arg(*args) for args in _all_args]
+ALL_ARGS = ArgsList(_args)
+
+BBOX_SORTBY = [
+ "None",
+ "Position (left to right)",
+ "Position (center to edge)",
+ "Area (large to small)",
+]
+MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"]
diff --git a/adetailer/adetailer/common.py b/adetailer/adetailer/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b697826f80b906cfc8debcb2a99bde78fd3691d
--- /dev/null
+++ b/adetailer/adetailer/common.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional, Union
+
+from huggingface_hub import hf_hub_download
+from PIL import Image, ImageDraw
+from rich import print
+
+repo_id = "Bingsu/adetailer"
+_download_failed = False
+
+
+@dataclass
+class PredictOutput:
+ """
+ Result of one detection pass.
+
+ All fields default to "nothing detected": empty lists and no preview.
+ """
+
+ # bounding boxes, each a list of [x1, y1, x2, y2]
+ bboxes: list[list[int | float]] = field(default_factory=list)
+ # mask images for the detections
+ masks: list[Image.Image] = field(default_factory=list)
+ # image with the detections drawn on it, or None
+ preview: Optional[Image.Image] = None
+
+
+def hf_download(file: str):
+ """
+ Fetch ``file`` from the ``repo_id`` repository via ``hf_hub_download``.
+
+ Returns
+ -------
+ The value returned by ``hf_hub_download``, or the sentinel string
+ ``"INVALID"`` if it raised.
+
+ After the first failure the module-level ``_download_failed`` flag is set
+ and every later call returns ``"INVALID"`` immediately without retrying.
+ """
+ global _download_failed
+
+ if _download_failed:
+ return "INVALID"
+
+ try:
+ path = hf_hub_download(repo_id, file)
+ except Exception:
+ msg = f"[-] ADetailer: Failed to load model {file!r} from huggingface"
+ print(msg)
+ path = "INVALID"
+ _download_failed = True
+ return path
+
+
+def scan_model_dir(path_: str | Path) -> list[Path]:
+ """
+ Recursively collect the ``.pt`` / ``.pth`` files under ``path_``.
+
+ Returns an empty list when ``path_`` is falsy or is not a directory.
+ """
+ if not path_ or not (path := Path(path_)).is_dir():
+ return []
+ return [p for p in path.rglob("*") if p.is_file() and p.suffix in (".pt", ".pth")]
+
+
+def get_models(
+ model_dir: str | Path, extra_dir: str | Path = "", huggingface: bool = True
+) -> OrderedDict[str, str | None]:
+ """
+ Build the ordered mapping of model name to model path.
+
+ Parameters
+ ----------
+ model_dir: str | Path
+ directory scanned for local model files
+ extra_dir: str | Path
+ optional second directory scanned the same way
+ huggingface: bool
+ whether to add the models fetched with ``hf_download``
+
+ Returns
+ -------
+ OrderedDict[str, str | None]
+ Hugging Face models first (if enabled), then the mediapipe entries
+ (whose value is None), then local files keyed by file name.
+ """
+ model_paths = [*scan_model_dir(model_dir), *scan_model_dir(extra_dir)]
+
+ models = OrderedDict()
+ if huggingface:
+ models.update(
+ {
+ "face_yolov8n.pt": hf_download("face_yolov8n.pt"),
+ "face_yolov8s.pt": hf_download("face_yolov8s.pt"),
+ "hand_yolov8n.pt": hf_download("hand_yolov8n.pt"),
+ "person_yolov8n-seg.pt": hf_download("person_yolov8n-seg.pt"),
+ "person_yolov8s-seg.pt": hf_download("person_yolov8s-seg.pt"),
+ }
+ )
+ models.update(
+ {
+ "mediapipe_face_full": None,
+ "mediapipe_face_short": None,
+ "mediapipe_face_mesh": None,
+ "mediapipe_face_mesh_eyes_only": None,
+ }
+ )
+
+ # Remove the entries whose download failed ("INVALID" sentinel from hf_download).
+ invalid_keys = [k for k, v in models.items() if v == "INVALID"]
+ for key in invalid_keys:
+ models.pop(key)
+
+ # Local files never override a name that is already present.
+ for path in model_paths:
+ if path.name in models:
+ continue
+ models[path.name] = str(path)
+
+ return models
+
+
+def create_mask_from_bbox(
+ bboxes: list[list[float]], shape: tuple[int, int]
+) -> list[Image.Image]:
+ """
+ Parameters
+ ----------
+ bboxes: list[list[float]]
+ list of [x1, y1, x2, y2]
+ bounding boxes
+ shape: tuple[int, int]
+ shape of the image (width, height)
+
+ Returns
+ -------
+ masks: list[Image.Image]
+ A list of masks
+
+ """
+ masks = []
+ for bbox in bboxes:
+ mask = Image.new("L", shape, 0)
+ mask_draw = ImageDraw.Draw(mask)
+ mask_draw.rectangle(bbox, fill=255)
+ masks.append(mask)
+ return masks
+
+
+def create_bbox_from_mask(
+ masks: list[Image.Image], shape: tuple[int, int]
+) -> list[list[int]]:
+ """
+ Parameters
+ ----------
+ masks: list[Image.Image]
+ A list of masks
+ shape: tuple[int, int]
+ shape of the image (width, height); every mask is resized to it
+
+ Returns
+ -------
+ bboxes: list[list[int]]
+ A list of bounding boxes. Masks for which ``getbbox()`` returns
+ None are skipped, so the result can be shorter than ``masks``.
+
+ """
+ bboxes = []
+ for mask in masks:
+ mask = mask.resize(shape)
+ bbox = mask.getbbox()
+ if bbox is not None:
+ bboxes.append(list(bbox))
+ return bboxes
diff --git a/adetailer/adetailer/mask.py b/adetailer/adetailer/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f3680b2659eeea54b142f45c6de9e99445b546
--- /dev/null
+++ b/adetailer/adetailer/mask.py
@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+from enum import IntEnum
+from functools import partial, reduce
+from math import dist
+
+import cv2
+import numpy as np
+from PIL import Image, ImageChops
+
+from adetailer.args import MASK_MERGE_INVERT
+from adetailer.common import PredictOutput
+
+
+class SortBy(IntEnum):
+ """Bounding-box orderings understood by ``sort_bboxes``."""
+
+ NONE = 0
+ LEFT_TO_RIGHT = 1
+ CENTER_TO_EDGE = 2
+ AREA = 3
+
+
+class MergeInvert(IntEnum):
+ """
+ Mask merge modes understood by ``mask_merge_invert``.
+
+ The values match the positions in ``MASK_MERGE_INVERT``, which is how a
+ string mode is converted to a number there.
+ """
+
+ NONE = 0
+ MERGE = 1
+ MERGE_INVERT = 2
+
+
+def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
+ return cv2.dilate(arr, kernel, iterations=1)
+
+
+def _erode(arr: np.ndarray, value: int) -> np.ndarray:
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
+ return cv2.erode(arr, kernel, iterations=1)
+
+
+def dilate_erode(img: Image.Image, value: int) -> Image.Image:
+ """
+ The dilate_erode function takes an image and a value.
+ If the value is positive, it dilates the image by that amount.
+ If the value is negative, it erodes the image by that amount.
+
+ Parameters
+ ----------
+ img: PIL.Image.Image
+ the image to be processed
+ value: int
+ kernel size of dilation or erosion
+
+ Returns
+ -------
+ PIL.Image.Image
+ The image that has been dilated or eroded
+ """
+ if value == 0:
+ return img
+
+ arr = np.array(img)
+ arr = _dilate(arr, value) if value > 0 else _erode(arr, -value)
+
+ return Image.fromarray(arr)
+
+
+def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
+ """
+ The offset function takes an image and offsets it by a given x(→) and y(↑) value.
+
+ Parameters
+ ----------
+ img: Image.Image
+ The image (mask) to shift
+ x: int
+ →
+ y: int
+ ↑ (negated before being passed to ``ImageChops.offset``)
+
+ Returns
+ -------
+ PIL.Image.Image
+ A new image that is offset by x and y
+ """
+ return ImageChops.offset(img, x, -y)
+
+
+def is_all_black(img: Image.Image) -> bool:
+ arr = np.array(img)
+ return cv2.countNonZero(arr) == 0
+
+
+def bbox_area(bbox: list[float]):
+ """Area of a ``[x1, y1, x2, y2]`` box: ``(x2 - x1) * (y2 - y1)``."""
+ return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+
+
+def mask_preprocess(
+ masks: list[Image.Image],
+ kernel: int = 0,
+ x_offset: int = 0,
+ y_offset: int = 0,
+ merge_invert: int | MergeInvert | str = MergeInvert.NONE,
+) -> list[Image.Image]:
+ """
+ The mask_preprocess function takes a list of masks and preprocesses them.
+ It offsets them by x_offset and y_offset, dilates or erodes them,
+ discards masks that are entirely black, and finally applies the
+ merge / invert mode.
+
+ Parameters
+ ----------
+ masks: list[Image.Image]
+ A list of masks
+ kernel: int
+ kernel size of dilation or erosion
+ x_offset: int
+ →
+ y_offset: int
+ ↑
+ merge_invert: int | MergeInvert | str
+ merge mode forwarded to ``mask_merge_invert``
+
+ Returns
+ -------
+ list[Image.Image]
+ A list of processed masks (empty if ``masks`` is empty)
+ """
+ if not masks:
+ return []
+
+ if x_offset != 0 or y_offset != 0:
+ masks = [offset(m, x_offset, y_offset) for m in masks]
+
+ if kernel != 0:
+ masks = [dilate_erode(m, kernel) for m in masks]
+ masks = [m for m in masks if not is_all_black(m)]
+
+ return mask_merge_invert(masks, mode=merge_invert)
+
+
+# Bbox sorting
+def _key_left_to_right(bbox: list[float]) -> float:
+ """
+ Left to right
+
+ Parameters
+ ----------
+ bbox: list[float]
+ list of [x1, y1, x2, y2]
+ """
+ return bbox[0]
+
+
+def _key_center_to_edge(bbox: list[float], *, center: tuple[float, float]) -> float:
+ """
+ Center to edge
+
+ Parameters
+ ----------
+ bbox: list[float]
+ list of [x1, y1, x2, y2]
+ center: tuple[float, float]
+ the (x, y) point that distances are measured from
+
+ Returns
+ -------
+ float
+ Euclidean distance between ``center`` and the center of ``bbox``
+ """
+ bbox_center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
+ return dist(center, bbox_center)
+
+
+def _key_area(bbox: list[float]) -> float:
+ """
+ Large to small
+
+ Parameters
+ ----------
+ bbox: list[float]
+ list of [x1, y1, x2, y2]
+ """
+ return -bbox_area(bbox)
+
+
+def sort_bboxes(
+ pred: PredictOutput, order: int | SortBy = SortBy.NONE
+) -> PredictOutput:
+ """
+ Reorder ``pred.bboxes`` and ``pred.masks`` together according to ``order``.
+
+ ``pred`` is modified in place and also returned. It is returned untouched
+ when ``order`` is ``SortBy.NONE`` or there is at most one bbox.
+
+ ``SortBy.CENTER_TO_EDGE`` reads ``pred.preview.size``, so ``pred.preview``
+ must not be None for that order.
+
+ Raises
+ ------
+ RuntimeError
+ If ``order`` is not one of the ``SortBy`` values
+ """
+ if order == SortBy.NONE or len(pred.bboxes) <= 1:
+ return pred
+
+ if order == SortBy.LEFT_TO_RIGHT:
+ key = _key_left_to_right
+ elif order == SortBy.CENTER_TO_EDGE:
+ width, height = pred.preview.size
+ center = (width / 2, height / 2)
+ key = partial(_key_center_to_edge, center=center)
+ elif order == SortBy.AREA:
+ key = _key_area
+ else:
+ raise RuntimeError
+
+ # Sort indices rather than the boxes so masks can follow the same permutation.
+ items = len(pred.bboxes)
+ idx = sorted(range(items), key=lambda i: key(pred.bboxes[i]))
+ pred.bboxes = [pred.bboxes[i] for i in idx]
+ pred.masks = [pred.masks[i] for i in idx]
+ return pred
+
+
+# Filter by ratio
+def is_in_ratio(bbox: list[float], low: float, high: float, orig_area: int) -> bool:
+ """True if ``bbox`` area divided by ``orig_area`` lies in ``[low, high]`` (inclusive)."""
+ area = bbox_area(bbox)
+ return low <= area / orig_area <= high
+
+
+def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
+ """
+ Keep only the detections whose bbox-to-image area ratio is in ``[low, high]``.
+
+ The image area is taken from ``pred.preview.size``. ``pred`` is modified
+ in place and returned; it is returned untouched when it has no bboxes.
+ """
+ if not pred.bboxes:
+ return pred
+
+ w, h = pred.preview.size
+ orig_area = w * h
+ items = len(pred.bboxes)
+ idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
+ pred.bboxes = [pred.bboxes[i] for i in idx]
+ pred.masks = [pred.masks[i] for i in idx]
+ return pred
+
+
+def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput:
+ """
+ Keep the ``k`` detections with the largest bbox area, largest first.
+
+ ``k == 0`` disables the filter. ``pred`` is modified in place and returned.
+ """
+ if not pred.bboxes or k == 0:
+ return pred
+ areas = [bbox_area(bbox) for bbox in pred.bboxes]
+ # argsort is ascending: take the last k indices, then reverse to get descending area.
+ idx = np.argsort(areas)[-k:]
+ idx = idx[::-1]
+ pred.bboxes = [pred.bboxes[i] for i in idx]
+ pred.masks = [pred.masks[i] for i in idx]
+ return pred
+
+
+# Merge / Invert
+def mask_merge(masks: list[Image.Image]) -> list[Image.Image]:
+ """
+ Combine all masks into one with a bitwise OR and return it as a one-item list.
+
+ ``reduce`` is called without an initial value, so ``masks`` must not be empty.
+ """
+ arrs = [np.array(m) for m in masks]
+ arr = reduce(cv2.bitwise_or, arrs)
+ return [Image.fromarray(arr)]
+
+
+def mask_invert(masks: list[Image.Image]) -> list[Image.Image]:
+ return [ImageChops.invert(m) for m in masks]
+
+
+def mask_merge_invert(
+ masks: list[Image.Image], mode: int | MergeInvert | str
+) -> list[Image.Image]:
+ """
+ Apply the requested merge mode to ``masks``.
+
+ Parameters
+ ----------
+ masks: list[Image.Image]
+ A list of masks
+ mode: int | MergeInvert | str
+ A ``MergeInvert`` value, or one of the ``MASK_MERGE_INVERT`` strings
+ (converted to its index in that list)
+
+ Returns
+ -------
+ list[Image.Image]
+ ``masks`` unchanged for NONE or empty input, otherwise a one-item
+ list holding the merged (and, for MERGE_INVERT, inverted) mask
+
+ Raises
+ ------
+ RuntimeError
+ If the numeric mode is not a ``MergeInvert`` value
+ """
+ if isinstance(mode, str):
+ mode = MASK_MERGE_INVERT.index(mode)
+
+ if mode == MergeInvert.NONE or not masks:
+ return masks
+
+ if mode == MergeInvert.MERGE:
+ return mask_merge(masks)
+
+ if mode == MergeInvert.MERGE_INVERT:
+ merged = mask_merge(masks)
+ return mask_invert(merged)
+
+ raise RuntimeError
diff --git a/adetailer/adetailer/mediapipe.py b/adetailer/adetailer/mediapipe.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd076191f765234991e17a09a071e78c4c934105
--- /dev/null
+++ b/adetailer/adetailer/mediapipe.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+from functools import partial
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw
+
+from adetailer import PredictOutput
+from adetailer.common import create_bbox_from_mask, create_mask_from_bbox
+
+
+def mediapipe_predict(
+ model_type: str, image: Image.Image, confidence: float = 0.3
+) -> PredictOutput:
+ """
+ Dispatch to the mediapipe detector named by ``model_type``.
+
+ Parameters
+ ----------
+ model_type: str
+ one of the keys of ``mapping`` below
+ image: Image.Image
+ the image to run detection on
+ confidence: float
+ forwarded to the selected detector
+
+ Raises
+ ------
+ RuntimeError
+ If ``model_type`` is not a known key
+ """
+ mapping = {
+ "mediapipe_face_short": partial(mediapipe_face_detection, 0),
+ "mediapipe_face_full": partial(mediapipe_face_detection, 1),
+ "mediapipe_face_mesh": mediapipe_face_mesh,
+ "mediapipe_face_mesh_eyes_only": mediapipe_face_mesh_eyes_only,
+ }
+ if model_type in mapping:
+ func = mapping[model_type]
+ return func(image, confidence)
+ msg = f"[-] ADetailer: Invalid mediapipe model type: {model_type}, Available: {list(mapping.keys())!r}"
+ raise RuntimeError(msg)
+
+
+def mediapipe_face_detection(
+ model_type: int, image: Image.Image, confidence: float = 0.3
+) -> PredictOutput:
+ """
+ Detect faces with mediapipe's ``FaceDetection`` and build rectangular masks.
+
+ Parameters
+ ----------
+ model_type: int
+ passed to ``FaceDetection`` as ``model_selection``
+ image: Image.Image
+ the image to run detection on
+ confidence: float
+ passed as ``min_detection_confidence``
+
+ Returns
+ -------
+ PredictOutput
+ Empty when nothing is detected; otherwise bboxes in pixel
+ coordinates, one rectangle mask per bbox, and a preview with the
+ detections drawn on a copy of the image
+ """
+ # Imported lazily, inside the function, rather than at module import time.
+ import mediapipe as mp
+
+ img_width, img_height = image.size
+
+ mp_face_detection = mp.solutions.face_detection
+ draw_util = mp.solutions.drawing_utils
+
+ img_array = np.array(image)
+
+ with mp_face_detection.FaceDetection(
+ model_selection=model_type, min_detection_confidence=confidence
+ ) as face_detector:
+ pred = face_detector.process(img_array)
+
+ if pred.detections is None:
+ return PredictOutput()
+
+ preview_array = img_array.copy()
+
+ bboxes = []
+ for detection in pred.detections:
+ draw_util.draw_detection(preview_array, detection)
+
+ # The relative box is scaled by the image size to get pixel coordinates.
+ bbox = detection.location_data.relative_bounding_box
+ x1 = bbox.xmin * img_width
+ y1 = bbox.ymin * img_height
+ w = bbox.width * img_width
+ h = bbox.height * img_height
+ x2 = x1 + w
+ y2 = y1 + h
+
+ bboxes.append([x1, y1, x2, y2])
+
+ masks = create_mask_from_bbox(bboxes, image.size)
+ preview = Image.fromarray(preview_array)
+
+ return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
+
+
+def mediapipe_face_mesh(image: Image.Image, confidence: float = 0.3) -> PredictOutput:
+ """
+ Detect faces with mediapipe's ``FaceMesh`` and mask each face's convex hull.
+
+ Parameters
+ ----------
+ image: Image.Image
+ the image to run detection on
+ confidence: float
+ passed as ``min_detection_confidence``
+
+ Returns
+ -------
+ PredictOutput
+ Empty when no face landmarks are found; otherwise one hull mask per
+ face, bboxes derived from those masks, and a preview with the mesh
+ tesselation drawn on a copy of the image
+ """
+ import mediapipe as mp
+
+ mp_face_mesh = mp.solutions.face_mesh
+ draw_util = mp.solutions.drawing_utils
+ drawing_styles = mp.solutions.drawing_styles
+
+ w, h = image.size
+
+ with mp_face_mesh.FaceMesh(
+ static_image_mode=True, max_num_faces=20, min_detection_confidence=confidence
+ ) as face_mesh:
+ arr = np.array(image)
+ pred = face_mesh.process(arr)
+
+ if pred.multi_face_landmarks is None:
+ return PredictOutput()
+
+ preview = arr.copy()
+ masks = []
+
+ for landmarks in pred.multi_face_landmarks:
+ draw_util.draw_landmarks(
+ image=preview,
+ landmark_list=landmarks,
+ connections=mp_face_mesh.FACEMESH_TESSELATION,
+ landmark_drawing_spec=None,
+ connection_drawing_spec=drawing_styles.get_default_face_mesh_tesselation_style(),
+ )
+
+ # Landmark coordinates are scaled by the image size, then the hull is
+ # flattened to [x0, y0, x1, y1, ...] for ImageDraw.polygon.
+ points = np.intp([(land.x * w, land.y * h) for land in landmarks.landmark])
+ outline = cv2.convexHull(points).reshape(-1).tolist()
+
+ mask = Image.new("L", image.size, "black")
+ draw = ImageDraw.Draw(mask)
+ draw.polygon(outline, fill="white")
+ masks.append(mask)
+
+ bboxes = create_bbox_from_mask(masks, image.size)
+ preview = Image.fromarray(preview)
+ return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
+
+
+def mediapipe_face_mesh_eyes_only(
+ image: Image.Image, confidence: float = 0.3
+) -> PredictOutput:
+ """
+ Detect faces with mediapipe's ``FaceMesh`` and mask only the two eye regions.
+
+ Parameters
+ ----------
+ image: Image.Image
+ the image to run detection on
+ confidence: float
+ passed as ``min_detection_confidence``
+
+ Returns
+ -------
+ PredictOutput
+ Empty when no face landmarks are found; otherwise one mask per face
+ (both eye hulls drawn on the same mask), bboxes derived from those
+ masks, and a preview produced by ``draw_preview``
+ """
+ import mediapipe as mp
+
+ mp_face_mesh = mp.solutions.face_mesh
+
+ # Landmark indices taken from the eye connection sets, flattened to 1-D.
+ left_idx = np.array(list(mp_face_mesh.FACEMESH_LEFT_EYE)).flatten()
+ right_idx = np.array(list(mp_face_mesh.FACEMESH_RIGHT_EYE)).flatten()
+
+ w, h = image.size
+
+ with mp_face_mesh.FaceMesh(
+ static_image_mode=True, max_num_faces=20, min_detection_confidence=confidence
+ ) as face_mesh:
+ arr = np.array(image)
+ pred = face_mesh.process(arr)
+
+ if pred.multi_face_landmarks is None:
+ return PredictOutput()
+
+ preview = image.copy()
+ masks = []
+
+ for landmarks in pred.multi_face_landmarks:
+ points = np.intp([(land.x * w, land.y * h) for land in landmarks.landmark])
+ left_eyes = points[left_idx]
+ right_eyes = points[right_idx]
+ left_outline = cv2.convexHull(left_eyes).reshape(-1).tolist()
+ right_outline = cv2.convexHull(right_eyes).reshape(-1).tolist()
+
+ mask = Image.new("L", image.size, "black")
+ draw = ImageDraw.Draw(mask)
+ for outline in (left_outline, right_outline):
+ draw.polygon(outline, fill="white")
+ masks.append(mask)
+
+ bboxes = create_bbox_from_mask(masks, image.size)
+ preview = draw_preview(preview, bboxes, masks)
+ return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
+
+
+def draw_preview(
+ preview: Image.Image, bboxes: list[list[int]], masks: list[Image.Image]
+) -> Image.Image:
+ """
+ Tint every masked region red and outline every bbox in red.
+
+ Each mask is blended in at 25% strength. The blended result is a new
+ image, so the image returned is not the one passed in when ``masks`` is
+ non-empty; the rectangles are drawn onto whichever image results.
+ """
+ red = Image.new("RGB", preview.size, "red")
+ for mask in masks:
+ masked = Image.composite(red, preview, mask)
+ preview = Image.blend(preview, masked, 0.25)
+
+ draw = ImageDraw.Draw(preview)
+ for bbox in bboxes:
+ draw.rectangle(bbox, outline="red", width=2)
+
+ return preview
diff --git a/adetailer/adetailer/traceback.py b/adetailer/adetailer/traceback.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d1848ed7b24ddeaa2aed40f305a939237ada2d
--- /dev/null
+++ b/adetailer/adetailer/traceback.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import io
+import platform
+import sys
+from importlib.metadata import version
+from typing import Any, Callable
+
+from rich.console import Console, Group
+from rich.panel import Panel
+from rich.table import Table
+from rich.traceback import Traceback
+
+from adetailer.__version__ import __version__
+
+
+def processing(*args: Any) -> dict[str, Any]:
+ """
+ Summarize the first Stable Diffusion processing object found in ``args``.
+
+ Returns an empty dict when ``modules.processing`` cannot be imported or
+ when no argument is a txt2img / img2img processing instance. Otherwise
+ returns selected generation settings merged with ``sd_models()``.
+ """
+ try:
+ from modules.processing import (
+ StableDiffusionProcessingImg2Img,
+ StableDiffusionProcessingTxt2Img,
+ )
+ except ImportError:
+ return {}
+
+ p = None
+ for arg in args:
+ if isinstance(
+ arg, (StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img)
+ ):
+ p = arg
+ break
+
+ if p is None:
+ return {}
+
+ info = {
+ "prompt": p.prompt,
+ "negative_prompt": p.negative_prompt,
+ "n_iter": p.n_iter,
+ "batch_size": p.batch_size,
+ "width": p.width,
+ "height": p.height,
+ "sampler_name": p.sampler_name,
+ # read with getattr + default, so the object may lack these attributes
+ "enable_hr": getattr(p, "enable_hr", False),
+ "hr_upscaler": getattr(p, "hr_upscaler", ""),
+ }
+
+ info.update(sd_models())
+ return info
+
+
+def sd_models() -> dict[str, str]:
+ """
+ Report the checkpoint / VAE / UNet names stored in ``modules.shared.opts``.
+
+ Returns an empty dict if ``shared.opts`` cannot be obtained; an option
+ that is missing is reported as ``"------"``.
+ """
+ try:
+ from modules import shared
+
+ opts = shared.opts
+ except Exception:
+ return {}
+
+ return {
+ "checkpoint": getattr(opts, "sd_model_checkpoint", "------"),
+ "vae": getattr(opts, "sd_vae", "------"),
+ "unet": getattr(opts, "sd_unet", "------"),
+ }
+
+
+def ad_args(*args: Any) -> dict[str, Any]:
+ """
+ Summarize the first ADetailer argument dict found in ``args``.
+
+ Only dicts whose ``"ad_model"`` is present and not ``"None"`` are
+ considered; an empty dict is returned when there is none. ``"is_api"`` is
+ reported as False exactly when the stored value is a tuple.
+ """
+ ad_args = [
+ arg
+ for arg in args
+ if isinstance(arg, dict) and arg.get("ad_model", "None") != "None"
+ ]
+ if not ad_args:
+ return {}
+
+ arg0 = ad_args[0]
+ is_api = arg0.get("is_api", True)
+ return {
+ "version": __version__,
+ "ad_model": arg0["ad_model"],
+ "ad_prompt": arg0.get("ad_prompt", ""),
+ "ad_negative_prompt": arg0.get("ad_negative_prompt", ""),
+ "ad_controlnet_model": arg0.get("ad_controlnet_model", "None"),
+ "is_api": type(is_api) is not tuple,
+ }
+
+
+def library_version():
+ """
+ Return ``{library name: installed version}`` for the libraries listed below.
+
+ A library whose version lookup raises is reported as ``"Unknown"``.
+ """
+ libraries = ["torch", "torchvision", "ultralytics", "mediapipe"]
+ d = {}
+ for lib in libraries:
+ try:
+ d[lib] = version(lib)
+ except Exception: # noqa: PERF203
+ d[lib] = "Unknown"
+ return d
+
+
+def sys_info() -> dict[str, Any]:
+ """
+ Collect platform, Python, web UI version/commit, command line and library versions.
+
+ The web UI version and commit come from the ``launch`` module; if that
+ fails for any reason, placeholder strings are used instead.
+ """
+ try:
+ import launch
+
+ # NOTE: this local name shadows the imported importlib.metadata.version
+ # inside this function only; library_version() is unaffected.
+ version = launch.git_tag()
+ commit = launch.commit_hash()
+ except Exception:
+ version = "Unknown (too old or vladmandic)"
+ commit = "Unknown"
+
+ return {
+ "Platform": platform.platform(),
+ "Python": sys.version,
+ "Version": version,
+ "Commit": commit,
+ "Commandline": sys.argv,
+ "Libraries": library_version(),
+ }
+
+
+def get_table(title: str, data: dict[str, Any]) -> Table:
+ """
+ Render ``data`` as a two-column rich ``Table`` titled ``title``.
+
+ Keys go in the first column; values that are not already strings are
+ converted with ``repr``.
+ """
+ table = Table(title=title, highlight=True)
+ table.add_column(" ", justify="right", style="dim")
+ table.add_column("Value")
+ for key, value in data.items():
+ if not isinstance(value, str):
+ value = repr(value)
+ table.add_row(key, value)
+
+ return table
+
+
+def rich_traceback(func: Callable) -> Callable:
+ def wrapper(*args, **kwargs):
+ string = io.StringIO()
+ width = Console().width
+ width = width - 4 if width > 4 else None
+ console = Console(file=string, width=width)
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ tables = [
+ get_table(title, data)
+ for title, data in [
+ ("System info", sys_info()),
+ ("Inputs", processing(*args)),
+ ("ADetailer", ad_args(*args)),
+ ]
+ if data
+ ]
+ tables.append(Traceback(extra_lines=1))
+
+ console.print(Panel(Group(*tables)))
+ output = "\n" + string.getvalue()
+
+ try:
+ error = e.__class__(output)
+ except Exception:
+ error = RuntimeError(output)
+ raise error from None
+
+ return wrapper
diff --git a/adetailer/adetailer/ui.py b/adetailer/adetailer/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6318ffab37d2d7570d52cc5fb264b06752e5904
--- /dev/null
+++ b/adetailer/adetailer/ui.py
@@ -0,0 +1,640 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from functools import partial
+from types import SimpleNamespace
+from typing import Any
+
+import gradio as gr
+
+from adetailer import AFTER_DETAILER, __version__
+from adetailer.args import ALL_ARGS, MASK_MERGE_INVERT
+from controlnet_ext import controlnet_exists, get_cn_models
+
+cn_module_choices = {
+ "inpaint": [
+ "inpaint_global_harmonious",
+ "inpaint_only",
+ "inpaint_only+lama",
+ ],
+ "lineart": [
+ "lineart_coarse",
+ "lineart_realistic",
+ "lineart_anime",
+ "lineart_anime_denoise",
+ ],
+ "openpose": ["openpose_full", "dw_openpose_full"],
+ "tile": ["tile_resample", "tile_colorfix", "tile_colorfix+sharp"],
+ "scribble": ["t2ia_sketch_pidi"],
+ "depth": ["depth_midas", "depth_hand_refiner"],
+}
+
+
+class Widgets(SimpleNamespace):
+ """Namespace holding one widget per ``ALL_ARGS`` attribute name."""
+
+ def tolist(self):
+ """Return the stored widgets as a list, in ``ALL_ARGS.attrs`` order."""
+ return [getattr(self, attr) for attr in ALL_ARGS.attrs]
+
+
+@dataclass
+class WebuiInfo:
+ ad_model_list: list[str]
+ sampler_names: list[str]
+ t2i_button: gr.Button
+ i2i_button: gr.Button
+ checkpoints_list: list[str]
+ vae_list: list[str]
+
+
+def gr_interactive(value: bool = True):
+ return gr.update(interactive=value)
+
+
+def ordinal(n: int) -> str:
+ """
+ Format ``n`` as an English ordinal, e.g. 1 -> "1st", 22 -> "22nd".
+
+ Numbers whose last two digits are 11, 12 or 13 always take "th".
+ """
+ d = {1: "st", 2: "nd", 3: "rd"}
+ return str(n) + ("th" if 11 <= n % 100 <= 13 else d.get(n % 10, "th"))
+
+
+def suffix(n: int, c: str = " ") -> str:
+ """
+ Label suffix for the zero-based slot ``n``.
+
+ Slot 0 gets no suffix; slot ``n`` gets ``c`` followed by the ordinal of
+ ``n + 1``, e.g. ``suffix(1)`` is ``" 2nd"``.
+ """
+ return "" if n == 0 else c + ordinal(n + 1)
+
+
+def on_widget_change(state: dict, value: Any, *, attr: str):
+ """
+ Store a widget's new ``value`` under ``attr`` in the state dict.
+
+ If the state carries an ``"is_api"`` entry, a copy is made and that entry
+ is removed from the copy. Returns the state dict that was written to.
+ """
+ if "is_api" in state:
+ state = state.copy()
+ state.pop("is_api")
+ state[attr] = value
+ return state
+
+
+def on_generate_click(state: dict, *values: Any):
+ """
+ Copy every widget value into ``state`` when the generate button is clicked.
+
+ ``values`` are matched to ``ALL_ARGS.attrs`` by position. ``"is_api"`` is
+ set to an empty tuple, which ``ADetailerArgs.is_api_validator`` turns
+ into ``is_api=False``.
+ """
+ for attr, value in zip(ALL_ARGS.attrs, values):
+ state[attr] = value
+ state["is_api"] = ()
+ return state
+
+
+def on_cn_model_update(cn_model_name: str):
+ """
+ Update the ControlNet module dropdown to fit the selected model name.
+
+ "inpaint_depth" in the name is first rewritten to "depth". The first key
+ of ``cn_module_choices`` contained in the name selects the choices, with
+ the first choice preselected; if no key matches, the dropdown is hidden
+ and reset to "None".
+ """
+ cn_model_name = cn_model_name.replace("inpaint_depth", "depth")
+ for t in cn_module_choices:
+ if t in cn_model_name:
+ choices = cn_module_choices[t]
+ return gr.update(visible=True, choices=choices, value=choices[0])
+ return gr.update(visible=False, choices=["None"], value="None")
+
+
+def elem_id(item_id: str, n: int, is_img2img: bool) -> str:
+ """
+ Build the element id for a widget.
+
+ The format is ``script_{txt2img|img2img}_adetailer_{item_id}`` plus, for
+ slots after the first, ``_`` and the slot ordinal, e.g.
+ ``elem_id("ad_model", 1, False)`` is
+ ``"script_txt2img_adetailer_ad_model_2nd"``.
+ """
+ tap = "img2img" if is_img2img else "txt2img"
+ suf = suffix(n, "_")
+ return f"script_{tap}_adetailer_{item_id}{suf}"
+
+
+def state_init(w: Widgets) -> dict[str, Any]:
+ """Build the initial state dict from the ``.value`` of every widget in ``w``."""
+ return {attr: getattr(w, attr).value for attr in ALL_ARGS.attrs}
+
+
+def adui(
+ num_models: int,
+ is_img2img: bool,
+ webui_info: WebuiInfo,
+):
+ """
+ Build the whole ADetailer accordion.
+
+ Parameters
+ ----------
+ num_models: int
+ number of model tabs to create
+ is_img2img: bool
+ selects the img2img element ids and makes "Skip img2img" visible
+ webui_info: WebuiInfo
+ forwarded to ``one_ui_group`` for every tab
+
+ Returns
+ -------
+ components: list
+ the enable checkbox, the skip-img2img checkbox, then one state per tab
+ infotext_fields: list
+ (component, label) pairs for the two checkboxes and every tab's widgets
+ """
+ states = []
+ infotext_fields = []
+ eid = partial(elem_id, n=0, is_img2img=is_img2img)
+
+ with gr.Accordion(AFTER_DETAILER, open=False, elem_id=eid("ad_main_accordion")):
+ with gr.Row():
+ with gr.Column(scale=6):
+ ad_enable = gr.Checkbox(
+ label="Enable ADetailer",
+ value=False,
+ visible=True,
+ elem_id=eid("ad_enable"),
+ )
+
+ with gr.Column(scale=6):
+ ad_skip_img2img = gr.Checkbox(
+ label="Skip img2img",
+ value=False,
+ visible=is_img2img,
+ elem_id=eid("ad_skip_img2img"),
+ )
+
+ with gr.Column(scale=1, min_width=180):
+ gr.Markdown(
+ f"v{__version__}",
+ elem_id=eid("ad_version"),
+ )
+
+ infotext_fields.append((ad_enable, "ADetailer enable"))
+ infotext_fields.append((ad_skip_img2img, "ADetailer skip img2img"))
+
+ # One tab per model slot, titled "1st", "2nd", ...
+ with gr.Group(), gr.Tabs():
+ for n in range(num_models):
+ with gr.Tab(ordinal(n + 1)):
+ state, infofields = one_ui_group(
+ n=n,
+ is_img2img=is_img2img,
+ webui_info=webui_info,
+ )
+
+ states.append(state)
+ infotext_fields.extend(infofields)
+
+ # components: [bool, dict, dict, ...]
+ components = [ad_enable, ad_skip_img2img, *states]
+ return components, infotext_fields
+
+
+def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
+ """
+ Build the widgets for model slot ``n`` and wire them to a ``gr.State``.
+
+ Parameters
+ ----------
+ n: int
+ zero-based slot index; used for label suffixes and element ids
+ is_img2img: bool
+ selects the img2img element ids and generate button
+ webui_info: WebuiInfo
+ supplies the model list and the generate buttons
+
+ Returns
+ -------
+ state: gr.State
+ state holding a dict of this slot's widget values
+ infotext_fields: list
+ (widget, label + slot suffix) pairs for every ``ALL_ARGS`` entry
+ """
+ w = Widgets()
+ eid = partial(elem_id, n=n, is_img2img=is_img2img)
+
+ with gr.Row():
+ # The first slot defaults to the first model; later slots default to "None".
+ model_choices = (
+ [*webui_info.ad_model_list, "None"]
+ if n == 0
+ else ["None", *webui_info.ad_model_list]
+ )
+
+ w.ad_model = gr.Dropdown(
+ label="ADetailer model" + suffix(n),
+ choices=model_choices,
+ value=model_choices[0],
+ visible=True,
+ type="value",
+ elem_id=eid("ad_model"),
+ )
+
+ with gr.Group():
+ with gr.Row(elem_id=eid("ad_toprow_prompt")):
+ w.ad_prompt = gr.Textbox(
+ label="ad_prompt" + suffix(n),
+ show_label=False,
+ lines=3,
+ placeholder="ADetailer prompt"
+ + suffix(n)
+ + "\nIf blank, the main prompt is used.",
+ elem_id=eid("ad_prompt"),
+ )
+
+ with gr.Row(elem_id=eid("ad_toprow_negative_prompt")):
+ w.ad_negative_prompt = gr.Textbox(
+ label="ad_negative_prompt" + suffix(n),
+ show_label=False,
+ lines=2,
+ placeholder="ADetailer negative prompt"
+ + suffix(n)
+ + "\nIf blank, the main negative prompt is used.",
+ elem_id=eid("ad_negative_prompt"),
+ )
+
+ with gr.Group():
+ with gr.Accordion(
+ "Detection", open=False, elem_id=eid("ad_detection_accordion")
+ ):
+ detection(w, n, is_img2img)
+
+ with gr.Accordion(
+ "Mask Preprocessing",
+ open=False,
+ elem_id=eid("ad_mask_preprocessing_accordion"),
+ ):
+ mask_preprocessing(w, n, is_img2img)
+
+ with gr.Accordion(
+ "Inpainting", open=False, elem_id=eid("ad_inpainting_accordion")
+ ):
+ inpainting(w, n, is_img2img, webui_info)
+
+ with gr.Group():
+ controlnet(w, n, is_img2img)
+
+ # The initial value is given as a callable that reads the widgets' values.
+ state = gr.State(lambda: state_init(w))
+
+ # Every widget change writes that widget's value into the state.
+ for attr in ALL_ARGS.attrs:
+ widget = getattr(w, attr)
+ on_change = partial(on_widget_change, attr=attr)
+ widget.change(fn=on_change, inputs=[state, widget], outputs=state, queue=False)
+
+ # Clicking generate copies all widget values into the state in one go.
+ all_inputs = [state, *w.tolist()]
+ target_button = webui_info.i2i_button if is_img2img else webui_info.t2i_button
+ target_button.click(
+ fn=on_generate_click, inputs=all_inputs, outputs=state, queue=False
+ )
+
+ infotext_fields = [(getattr(w, attr), name + suffix(n)) for attr, name in ALL_ARGS]
+
+ return state, infotext_fields
+
+
+def detection(w: Widgets, n: int, is_img2img: bool):
+ """
+ Create the detection sliders and store them on `w`.
+
+ Sets `w.ad_confidence`, `w.ad_mask_k_largest`, `w.ad_mask_min_ratio` and
+ `w.ad_mask_max_ratio`. Every label has `suffix(n)` appended and every
+ element id is produced by `elem_id` bound to this `n` / `is_img2img`.
+ """
+ eid = partial(elem_id, n=n, is_img2img=is_img2img)
+
+ with gr.Row():
+ with gr.Column(variant="compact"):
+ # Confidence threshold, 0.0-1.0 in steps of 0.01, default 0.3.
+ w.ad_confidence = gr.Slider(
+ label="Detection model confidence threshold" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=0.3,
+ visible=True,
+ elem_id=eid("ad_confidence"),
+ )
+ # Per the label, 0 disables the "top k largest" restriction.
+ w.ad_mask_k_largest = gr.Slider(
+ label="Mask only the top k largest (0 to disable)" + suffix(n),
+ minimum=0,
+ maximum=10,
+ step=1,
+ value=0,
+ visible=True,
+ elem_id=eid("ad_mask_k_largest"),
+ )
+
+ with gr.Column(variant="compact"):
+ # Min/max area ratio sliders; defaults 0.0 and 1.0 span the full range.
+ w.ad_mask_min_ratio = gr.Slider(
+ label="Mask min area ratio" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.001,
+ value=0.0,
+ visible=True,
+ elem_id=eid("ad_mask_min_ratio"),
+ )
+ w.ad_mask_max_ratio = gr.Slider(
+ label="Mask max area ratio" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.001,
+ value=1.0,
+ visible=True,
+ elem_id=eid("ad_mask_max_ratio"),
+ )
+
+
+def mask_preprocessing(w: Widgets, n: int, is_img2img: bool):
+ """
+ Create the mask preprocessing widgets and store them on `w`.
+
+ Sets `w.ad_x_offset`, `w.ad_y_offset`, `w.ad_dilate_erode` and
+ `w.ad_mask_merge_invert`. Every label has `suffix(n)` appended and every
+ element id is produced by `elem_id` bound to this `n` / `is_img2img`.
+ """
+ eid = partial(elem_id, n=n, is_img2img=is_img2img)
+
+ with gr.Group():
+ with gr.Row():
+ with gr.Column(variant="compact"):
+ # Offsets range from -200 to 200 in steps of 1, default 0.
+ w.ad_x_offset = gr.Slider(
+ label="Mask x(→) offset" + suffix(n),
+ minimum=-200,
+ maximum=200,
+ step=1,
+ value=0,
+ visible=True,
+ elem_id=eid("ad_x_offset"),
+ )
+ w.ad_y_offset = gr.Slider(
+ label="Mask y(↑) offset" + suffix(n),
+ minimum=-200,
+ maximum=200,
+ step=1,
+ value=0,
+ visible=True,
+ elem_id=eid("ad_y_offset"),
+ )
+
+ with gr.Column(variant="compact"):
+ # Per the label: negative values erode, positive values dilate.
+ w.ad_dilate_erode = gr.Slider(
+ label="Mask erosion (-) / dilation (+)" + suffix(n),
+ minimum=-128,
+ maximum=128,
+ step=4,
+ value=4,
+ visible=True,
+ elem_id=eid("ad_dilate_erode"),
+ )
+
+ with gr.Row():
+ # Choices come from the MASK_MERGE_INVERT constant; "None" is the default.
+ w.ad_mask_merge_invert = gr.Radio(
+ label="Mask merge mode" + suffix(n),
+ choices=MASK_MERGE_INVERT,
+ value="None",
+ elem_id=eid("ad_mask_merge_invert"),
+ )
+
+
+def inpainting(w: Widgets, n: int, is_img2img: bool, webui_info: WebuiInfo):
+ """
+ Create the inpainting widgets and store them on `w`.
+
+ Most "Use separate ..." checkboxes are wired to their companion control
+ through `gr_interactive` via a `.change` event. The checkpoint and VAE
+ checkboxes have no such wiring in this function. Dropdown choices for
+ checkpoints, VAEs and samplers are read from `webui_info`.
+
+ Every label has `suffix(n)` appended and every element id is produced by
+ `elem_id` bound to this `n` / `is_img2img`.
+ """
+ eid = partial(elem_id, n=n, is_img2img=is_img2img)
+
+ with gr.Group():
+ with gr.Row():
+ w.ad_mask_blur = gr.Slider(
+ label="Inpaint mask blur" + suffix(n),
+ minimum=0,
+ maximum=64,
+ step=1,
+ value=4,
+ visible=True,
+ elem_id=eid("ad_mask_blur"),
+ )
+
+ w.ad_denoising_strength = gr.Slider(
+ label="Inpaint denoising strength" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=0.4,
+ visible=True,
+ elem_id=eid("ad_denoising_strength"),
+ )
+
+ with gr.Row():
+ with gr.Column(variant="compact"):
+ w.ad_inpaint_only_masked = gr.Checkbox(
+ label="Inpaint only masked" + suffix(n),
+ value=True,
+ visible=True,
+ elem_id=eid("ad_inpaint_only_masked"),
+ )
+ w.ad_inpaint_only_masked_padding = gr.Slider(
+ label="Inpaint only masked padding, pixels" + suffix(n),
+ minimum=0,
+ maximum=256,
+ step=4,
+ value=32,
+ visible=True,
+ elem_id=eid("ad_inpaint_only_masked_padding"),
+ )
+
+ # The checkbox value is passed through `gr_interactive` to update the padding slider.
+ w.ad_inpaint_only_masked.change(
+ gr_interactive,
+ inputs=w.ad_inpaint_only_masked,
+ outputs=w.ad_inpaint_only_masked_padding,
+ queue=False,
+ )
+
+ with gr.Column(variant="compact"):
+ w.ad_use_inpaint_width_height = gr.Checkbox(
+ label="Use separate width/height" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_inpaint_width_height"),
+ )
+
+ w.ad_inpaint_width = gr.Slider(
+ label="inpaint width" + suffix(n),
+ minimum=64,
+ maximum=2048,
+ step=4,
+ value=512,
+ visible=True,
+ elem_id=eid("ad_inpaint_width"),
+ )
+
+ w.ad_inpaint_height = gr.Slider(
+ label="inpaint height" + suffix(n),
+ minimum=64,
+ maximum=2048,
+ step=4,
+ value=512,
+ visible=True,
+ elem_id=eid("ad_inpaint_height"),
+ )
+
+ # One checkbox drives two outputs, so the handler returns one update per slider.
+ w.ad_use_inpaint_width_height.change(
+ lambda value: (gr_interactive(value), gr_interactive(value)),
+ inputs=w.ad_use_inpaint_width_height,
+ outputs=[w.ad_inpaint_width, w.ad_inpaint_height],
+ queue=False,
+ )
+
+ with gr.Row():
+ with gr.Column(variant="compact"):
+ w.ad_use_steps = gr.Checkbox(
+ label="Use separate steps" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_steps"),
+ )
+
+ w.ad_steps = gr.Slider(
+ label="ADetailer steps" + suffix(n),
+ minimum=1,
+ maximum=150,
+ step=1,
+ value=28,
+ visible=True,
+ elem_id=eid("ad_steps"),
+ )
+
+ w.ad_use_steps.change(
+ gr_interactive,
+ inputs=w.ad_use_steps,
+ outputs=w.ad_steps,
+ queue=False,
+ )
+
+ with gr.Column(variant="compact"):
+ w.ad_use_cfg_scale = gr.Checkbox(
+ label="Use separate CFG scale" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_cfg_scale"),
+ )
+
+ w.ad_cfg_scale = gr.Slider(
+ label="ADetailer CFG scale" + suffix(n),
+ minimum=0.0,
+ maximum=30.0,
+ step=0.5,
+ value=7.0,
+ visible=True,
+ elem_id=eid("ad_cfg_scale"),
+ )
+
+ w.ad_use_cfg_scale.change(
+ gr_interactive,
+ inputs=w.ad_use_cfg_scale,
+ outputs=w.ad_cfg_scale,
+ queue=False,
+ )
+
+ with gr.Row():
+ with gr.Column(variant="compact"):
+ w.ad_use_checkpoint = gr.Checkbox(
+ label="Use separate checkpoint" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_checkpoint"),
+ )
+
+ # The first entry is a sentinel meaning "keep the current checkpoint"; it is also the default.
+ ckpts = ["Use same checkpoint", *webui_info.checkpoints_list]
+
+ w.ad_checkpoint = gr.Dropdown(
+ label="ADetailer checkpoint" + suffix(n),
+ choices=ckpts,
+ value=ckpts[0],
+ visible=True,
+ elem_id=eid("ad_checkpoint"),
+ )
+
+ with gr.Column(variant="compact"):
+ w.ad_use_vae = gr.Checkbox(
+ label="Use separate VAE" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_vae"),
+ )
+
+ # Same sentinel-first layout as the checkpoint dropdown.
+ vaes = ["Use same VAE", *webui_info.vae_list]
+
+ w.ad_vae = gr.Dropdown(
+ label="ADetailer VAE" + suffix(n),
+ choices=vaes,
+ value=vaes[0],
+ visible=True,
+ elem_id=eid("ad_vae"),
+ )
+
+ with gr.Row(), gr.Column(variant="compact"):
+ w.ad_use_sampler = gr.Checkbox(
+ label="Use separate sampler" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_sampler"),
+ )
+
+ # Defaults to the first sampler name; indexing [0] requires a non-empty list.
+ w.ad_sampler = gr.Dropdown(
+ label="ADetailer sampler" + suffix(n),
+ choices=webui_info.sampler_names,
+ value=webui_info.sampler_names[0],
+ visible=True,
+ elem_id=eid("ad_sampler"),
+ )
+
+ w.ad_use_sampler.change(
+ gr_interactive,
+ inputs=w.ad_use_sampler,
+ outputs=w.ad_sampler,
+ queue=False,
+ )
+
+ with gr.Row():
+ with gr.Column(variant="compact"):
+ w.ad_use_noise_multiplier = gr.Checkbox(
+ label="Use separate noise multiplier" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_noise_multiplier"),
+ )
+
+ w.ad_noise_multiplier = gr.Slider(
+ label="Noise multiplier for img2img" + suffix(n),
+ minimum=0.5,
+ maximum=1.5,
+ step=0.01,
+ value=1.0,
+ visible=True,
+ elem_id=eid("ad_noise_multiplier"),
+ )
+
+ w.ad_use_noise_multiplier.change(
+ gr_interactive,
+ inputs=w.ad_use_noise_multiplier,
+ outputs=w.ad_noise_multiplier,
+ queue=False,
+ )
+
+ with gr.Column(variant="compact"):
+ w.ad_use_clip_skip = gr.Checkbox(
+ label="Use separate CLIP skip" + suffix(n),
+ value=False,
+ visible=True,
+ elem_id=eid("ad_use_clip_skip"),
+ )
+
+ w.ad_clip_skip = gr.Slider(
+ label="ADetailer CLIP skip" + suffix(n),
+ minimum=1,
+ maximum=12,
+ step=1,
+ value=1,
+ visible=True,
+ elem_id=eid("ad_clip_skip"),
+ )
+
+ w.ad_use_clip_skip.change(
+ gr_interactive,
+ inputs=w.ad_use_clip_skip,
+ outputs=w.ad_clip_skip,
+ queue=False,
+ )
+
+ with gr.Row(), gr.Column(variant="compact"):
+ w.ad_restore_face = gr.Checkbox(
+ label="Restore faces after ADetailer" + suffix(n),
+ value=False,
+ elem_id=eid("ad_restore_face"),
+ )
+
+
+def controlnet(w: Widgets, n: int, is_img2img: bool):
+ """
+ Create the ControlNet widgets and store them on `w`.
+
+ Sets `w.ad_controlnet_model`, `w.ad_controlnet_module`,
+ `w.ad_controlnet_weight`, `w.ad_controlnet_guidance_start` and
+ `w.ad_controlnet_guidance_end`. Each widget takes its `interactive` flag
+ from `controlnet_exists`.
+ """
+ eid = partial(elem_id, n=n, is_img2img=is_img2img)
+ # Two fixed entries come first, followed by whatever `get_cn_models()` returns.
+ cn_models = ["None", "Passthrough", *get_cn_models()]
+
+ with gr.Row(variant="panel"):
+ with gr.Column(variant="compact"):
+ w.ad_controlnet_model = gr.Dropdown(
+ label="ControlNet model" + suffix(n),
+ choices=cn_models,
+ value="None",
+ visible=True,
+ type="value",
+ interactive=controlnet_exists,
+ elem_id=eid("ad_controlnet_model"),
+ )
+
+ # Created hidden with a single "None" choice; it is the output of the
+ # model dropdown's change handler below.
+ w.ad_controlnet_module = gr.Dropdown(
+ label="ControlNet module" + suffix(n),
+ choices=["None"],
+ value="None",
+ visible=False,
+ type="value",
+ interactive=controlnet_exists,
+ elem_id=eid("ad_controlnet_module"),
+ )
+
+ w.ad_controlnet_weight = gr.Slider(
+ label="ControlNet weight" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=1.0,
+ visible=True,
+ interactive=controlnet_exists,
+ elem_id=eid("ad_controlnet_weight"),
+ )
+
+ # Changing the model feeds its value to `on_cn_model_update`, whose
+ # result updates the module dropdown.
+ w.ad_controlnet_model.change(
+ on_cn_model_update,
+ inputs=w.ad_controlnet_model,
+ outputs=w.ad_controlnet_module,
+ queue=False,
+ )
+
+ with gr.Column(variant="compact"):
+ w.ad_controlnet_guidance_start = gr.Slider(
+ label="ControlNet guidance start" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=0.0,
+ visible=True,
+ interactive=controlnet_exists,
+ elem_id=eid("ad_controlnet_guidance_start"),
+ )
+
+ w.ad_controlnet_guidance_end = gr.Slider(
+ label="ControlNet guidance end" + suffix(n),
+ minimum=0.0,
+ maximum=1.0,
+ step=0.01,
+ value=1.0,
+ visible=True,
+ interactive=controlnet_exists,
+ elem_id=eid("ad_controlnet_guidance_end"),
+ )
diff --git a/adetailer/adetailer/ultralytics.py b/adetailer/adetailer/ultralytics.py
new file mode 100644
index 0000000000000000000000000000000000000000..36062b7893a786a818cf75ab32370129e1325f0e
--- /dev/null
+++ b/adetailer/adetailer/ultralytics.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import cv2
+from PIL import Image
+from torchvision.transforms.functional import to_pil_image
+
+from adetailer import PredictOutput
+from adetailer.common import create_mask_from_bbox
+
+
+def ultralytics_predict(
+ model_path: str | Path,
+ image: Image.Image,
+ confidence: float = 0.3,
+ device: str = "",
+) -> PredictOutput:
+ from ultralytics import YOLO
+
+ model = YOLO(model_path)
+ pred = model(image, conf=confidence, device=device)
+
+ bboxes = pred[0].boxes.xyxy.cpu().numpy()
+ if bboxes.size == 0:
+ return PredictOutput()
+ bboxes = bboxes.tolist()
+
+ if pred[0].masks is None:
+ masks = create_mask_from_bbox(bboxes, image.size)
+ else:
+ masks = mask_to_pil(pred[0].masks.data, image.size)
+ preview = pred[0].plot()
+ preview = cv2.cvtColor(preview, cv2.COLOR_BGR2RGB)
+ preview = Image.fromarray(preview)
+
+ return PredictOutput(bboxes=bboxes, masks=masks, preview=preview)
+
+
+def mask_to_pil(masks, shape: tuple[int, int]) -> list[Image.Image]:
+ """
+ Parameters
+ ----------
+ masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
+ The device can be CUDA, but `to_pil_image` takes care of that.
+
+ shape: tuple[int, int]
+ (width, height) of the original image
+ """
+ n = masks.shape[0]
+ return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]
diff --git a/adetailer/controlnet_ext/__init__.py b/adetailer/controlnet_ext/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ab666835157561426d684d798735e724a5a4dbe
--- /dev/null
+++ b/adetailer/controlnet_ext/__init__.py
@@ -0,0 +1,7 @@
+from .controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models
+
+__all__ = [
+ "ControlNetExt",
+ "controlnet_exists",
+ "get_cn_models",
+]
diff --git a/adetailer/controlnet_ext/controlnet_ext.py b/adetailer/controlnet_ext/controlnet_ext.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f54f1289599acf0260d0b133b6f90b926e59f8c
--- /dev/null
+++ b/adetailer/controlnet_ext/controlnet_ext.py
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+import importlib
+import re
+import sys
+from functools import lru_cache
+from pathlib import Path
+from textwrap import dedent
+
+from modules import extensions, sd_models, shared
+
+# These names must be importable from `modules.paths`; if they are not, the
+# ImportError is re-raised as a RuntimeError carrying upgrade instructions.
+try:
+ from modules.paths import extensions_builtin_dir, extensions_dir, models_path
+except ImportError as e:
+ msg = """
+ [-] ADetailer: `stable-diffusion-webui < 1.1.0` is no longer supported.
+ Please upgrade to stable-diffusion-webui >= 1.1.0.
+ or you can use ADetailer v23.10.1 (https://github.com/Bing-su/adetailer/archive/refs/tags/v23.10.1.zip)
+ """
+ raise RuntimeError(dedent(msg)) from e
+
+ext_path = Path(extensions_dir)
+ext_builtin_path = Path(extensions_builtin_dir)
+# Defaults used when no matching extension is found by the loop below.
+controlnet_exists = False
+controlnet_path = None
+cn_base_path = ""
+
+# Pick the first enabled extension whose name contains "sd-webui-controlnet".
+for extension in extensions.active():
+ if not extension.enabled:
+ continue
+ # For cases like sd-webui-controlnet-master
+ if "sd-webui-controlnet" in extension.name:
+ controlnet_exists = True
+ controlnet_path = Path(extension.path)
+ # Dotted form of the last two path components; `ControlNetExt.init_controlnet`
+ # prepends it to ".scripts.external_code" to build an import path.
+ cn_base_path = ".".join(controlnet_path.parts[-2:])
+ break
+
+# When the extension lives directly under "extensions" or "extensions-builtin",
+# append that directory's parent to sys.path (presumably so the dotted import
+# path above can be resolved).
+if controlnet_path is not None:
+ sd_webui_controlnet_path = controlnet_path.resolve().parent
+ if sd_webui_controlnet_path.stem in ("extensions", "extensions-builtin"):
+ target_path = str(sd_webui_controlnet_path.parent)
+ if target_path not in sys.path:
+ sys.path.append(target_path)
+
+# Keyword found in a ControlNet model name -> module name used when the caller
+# did not choose a module (see `ControlNetExt.update_scripts_args`).
+cn_model_module = {
+ "inpaint": "inpaint_global_harmonious",
+ "scribble": "t2ia_sketch_pidi",
+ "lineart": "lineart_coarse",
+ "openpose": "openpose_full",
+ "tile": "tile_resample",
+ "depth": "depth_midas",
+}
+# Matches any of the keywords above; used to filter model names and file names.
+cn_model_regex = re.compile("|".join(cn_model_module.keys()))
+
+
+class ControlNetExt:
+ def __init__(self):
+ self.cn_models = ["None"]
+ self.cn_available = False
+ self.external_cn = None
+
+ def init_controlnet(self):
+ import_path = cn_base_path + ".scripts.external_code"
+
+ self.external_cn = importlib.import_module(import_path, "external_code")
+ self.cn_available = True
+ models = self.external_cn.get_models()
+ self.cn_models.extend(m for m in models if cn_model_regex.search(m))
+
+ def update_scripts_args(
+ self,
+ p,
+ model: str,
+ module: str | None,
+ weight: float,
+ guidance_start: float,
+ guidance_end: float,
+ ):
+ if (not self.cn_available) or model == "None":
+ return
+
+ if module is None or module == "None":
+ for m, v in cn_model_module.items():
+ if m in model:
+ module = v
+ break
+ else:
+ module = None
+
+ cn_units = [
+ self.external_cn.ControlNetUnit(
+ model=model,
+ weight=weight,
+ control_mode=self.external_cn.ControlMode.BALANCED,
+ module=module,
+ guidance_start=guidance_start,
+ guidance_end=guidance_end,
+ pixel_perfect=True,
+ )
+ ]
+
+ try:
+ self.external_cn.update_cn_script_in_processing(p, cn_units)
+ except AttributeError as e:
+ if "script_args_value" not in str(e):
+ raise
+ msg = "[-] Adetailer: ControlNet option not available in WEBUI version lower than 1.6.0 due to updates in ControlNet"
+ raise RuntimeError(msg) from e
+
+
+def get_cn_model_dirs() -> list[Path]:
+ cn_model_dir = Path(models_path, "ControlNet")
+ if controlnet_path is not None:
+ cn_model_dir_old = controlnet_path.joinpath("models")
+ else:
+ cn_model_dir_old = None
+ ext_dir1 = shared.opts.data.get("control_net_models_path", "")
+ ext_dir2 = getattr(shared.cmd_opts, "controlnet_dir", "")
+
+ dirs = [cn_model_dir]
+ dirs += [
+ Path(ext_dir) for ext_dir in [cn_model_dir_old, ext_dir1, ext_dir2] if ext_dir
+ ]
+
+ return dirs
+
+
+@lru_cache
+def _get_cn_models() -> list[str]:
+ """
+ Since we can't import ControlNet, we use a function that does something like
+ controlnet's `list(global_state.cn_models_names.values())`.
+ """
+ cn_model_exts = (".pt", ".pth", ".ckpt", ".safetensors")
+ dirs = get_cn_model_dirs()
+ name_filter = shared.opts.data.get("control_net_models_name_filter", "")
+ name_filter = name_filter.strip(" ").lower()
+
+ model_paths = []
+
+ for base in dirs:
+ if not base.exists():
+ continue
+
+ for p in base.rglob("*"):
+ if (
+ p.is_file()
+ and p.suffix in cn_model_exts
+ and cn_model_regex.search(p.name)
+ ):
+ if name_filter and name_filter not in p.name.lower():
+ continue
+ model_paths.append(p)
+ model_paths.sort(key=lambda p: p.name)
+
+ models = []
+ for p in model_paths:
+ model_hash = sd_models.model_hash(p)
+ name = f"{p.stem} [{model_hash}]"
+ models.append(name)
+ return models
+
+
+def get_cn_models() -> list[str]:
+ if controlnet_exists:
+ return _get_cn_models()
+ return []
diff --git a/adetailer/controlnet_ext/restore.py b/adetailer/controlnet_ext/restore.py
new file mode 100644
index 0000000000000000000000000000000000000000..152ffd0a9732b5512e4440b795a1ad89bcb27cab
--- /dev/null
+++ b/adetailer/controlnet_ext/restore.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+
+from modules import img2img, processing, shared
+
+
+class CNHijackRestore:
+ def __init__(self):
+ self.process = hasattr(processing, "__controlnet_original_process_images_inner")
+ self.img2img = hasattr(img2img, "__controlnet_original_process_batch")
+
+ def __enter__(self):
+ if self.process:
+ self.orig_process = processing.process_images_inner
+ processing.process_images_inner = getattr(
+ processing, "__controlnet_original_process_images_inner"
+ )
+ if self.img2img:
+ self.orig_img2img = img2img.process_batch
+ img2img.process_batch = getattr(
+ img2img, "__controlnet_original_process_batch"
+ )
+
+ def __exit__(self, *args, **kwargs):
+ if self.process:
+ processing.process_images_inner = self.orig_process
+ if self.img2img:
+ img2img.process_batch = self.orig_img2img
+
+
+@contextmanager
+def cn_allow_script_control():
+ orig = False
+ if "control_net_allow_script_control" in shared.opts.data:
+ try:
+ orig = shared.opts.data["control_net_allow_script_control"]
+ shared.opts.data["control_net_allow_script_control"] = True
+ yield
+ finally:
+ shared.opts.data["control_net_allow_script_control"] = orig
+ else:
+ yield
diff --git a/adetailer/install.py b/adetailer/install.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc7325d64ed0cd183c61f0e3fc3162516777316b
--- /dev/null
+++ b/adetailer/install.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+import importlib.util
+import subprocess
+import sys
+from importlib.metadata import version # python >= 3.8
+
+from packaging.version import parse
+
+import_name = {"py-cpuinfo": "cpuinfo", "protobuf": "google.protobuf"}
+
+
+def is_installed(
+ package: str, min_version: str | None = None, max_version: str | None = None
+):
+ name = import_name.get(package, package)
+ try:
+ spec = importlib.util.find_spec(name)
+ except ModuleNotFoundError:
+ return False
+
+ if spec is None:
+ return False
+
+ if not min_version and not max_version:
+ return True
+
+ if not min_version:
+ min_version = "0.0.0"
+ if not max_version:
+ max_version = "99999999.99999999.99999999"
+
+ try:
+ pkg_version = version(package)
+ return parse(min_version) <= parse(pkg_version) <= parse(max_version)
+ except Exception:
+ return False
+
+
+def run_pip(*args):
+ subprocess.run([sys.executable, "-m", "pip", "install", *args])
+
+
+def install():
+ deps = [
+ # requirements
+ ("ultralytics", "8.1.0", None),
+ ("mediapipe", "0.10.9", None),
+ ("rich", "13.0.0", None),
+ # mediapipe
+ ("protobuf", "3.20", "3.9999"),
+ ]
+
+ for pkg, low, high in deps:
+ if not is_installed(pkg, low, high):
+ if low and high:
+ cmd = f"{pkg}>={low},<={high}"
+ elif low:
+ cmd = f"{pkg}>={low}"
+ elif high:
+ cmd = f"{pkg}<={high}"
+ else:
+ cmd = pkg
+
+ run_pip("-U", cmd)
+
+
+# Read the launcher's `skip_install` flag. Any failure to import `launch` or
+# read the flag is treated as "do not skip".
+try:
+ import launch
+
+ skip_install = launch.args.skip_install
+except Exception:
+ skip_install = False
+
+# Runs at import time: dependencies are checked and installed unless skipped.
+if not skip_install:
+ install()
diff --git a/adetailer/preload.py b/adetailer/preload.py
new file mode 100644
index 0000000000000000000000000000000000000000..10be161f22b0a5ef7083609829a21b547eae9aea
--- /dev/null
+++ b/adetailer/preload.py
@@ -0,0 +1,9 @@
+import argparse
+
+
+def preload(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--ad-no-huggingface",
+ action="store_true",
+ help="Don't use adetailer models from huggingface",
+ )
diff --git a/adetailer/pyproject.toml b/adetailer/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..31b63421ce4f573d5b0bafc8061fe080480555b6
--- /dev/null
+++ b/adetailer/pyproject.toml
@@ -0,0 +1,42 @@
+[project]
+name = "adetailer"
+description = "An object detection and auto-mask extension for stable diffusion webui."
+authors = [{ name = "dowon", email = "ks2515@naver.com" }]
+requires-python = ">=3.8,<3.12"
+readme = "README.md"
+license = { text = "AGPL-3.0" }
+
+[project.urls]
+repository = "https://github.com/Bing-su/adetailer"
+
+[tool.isort]
+profile = "black"
+known_first_party = ["launch", "modules"]
+
+[tool.ruff]
+select = [
+ "A",
+ "B",
+ "C4",
+ "C90",
+ "E",
+ "EM",
+ "F",
+ "FA",
+ "I001",
+ "ISC",
+ "N",
+ "PERF",
+ "PIE",
+ "PT",
+ "PTH",
+ "RET",
+ "RUF",
+ "SIM",
+ "UP",
+ "W",
+]
+ignore = ["B008", "B905", "E501", "F401", "UP007"]
+
+[tool.ruff.isort]
+known-first-party = ["launch", "modules"]
diff --git a/adetailer/scripts/!adetailer.py b/adetailer/scripts/!adetailer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e4c7eacdccc7a575c36f124e4527c089194eacd
--- /dev/null
+++ b/adetailer/scripts/!adetailer.py
@@ -0,0 +1,1000 @@
+from __future__ import annotations
+
+import os
+import platform
+import re
+import sys
+import traceback
+from contextlib import contextmanager, suppress
+from copy import copy
+from functools import partial
+from pathlib import Path
+from textwrap import dedent
+from typing import TYPE_CHECKING, Any, NamedTuple
+
+import gradio as gr
+import torch
+from PIL import Image
+from rich import print
+from torchvision.transforms.functional import to_pil_image
+
+import modules
+from adetailer import (
+ AFTER_DETAILER,
+ __version__,
+ get_models,
+ mediapipe_predict,
+ ultralytics_predict,
+)
+from adetailer.args import ALL_ARGS, BBOX_SORTBY, ADetailerArgs, SkipImg2ImgOrig
+from adetailer.common import PredictOutput
+from adetailer.mask import (
+ filter_by_ratio,
+ filter_k_largest,
+ mask_preprocess,
+ sort_bboxes,
+)
+from adetailer.traceback import rich_traceback
+from adetailer.ui import WebuiInfo, adui, ordinal, suffix
+from controlnet_ext import ControlNetExt, controlnet_exists, get_cn_models
+from controlnet_ext.restore import (
+ CNHijackRestore,
+ cn_allow_script_control,
+)
+from modules import images, paths, safe, script_callbacks, scripts, shared
+from modules.devices import NansException
+from modules.processing import (
+ Processed,
+ StableDiffusionProcessingImg2Img,
+ create_infotext,
+ process_images,
+)
+from modules.sd_samplers import all_samplers
+from modules.shared import cmd_opts, opts, state
+
+if TYPE_CHECKING:
+ from fastapi import FastAPI
+
+# Value of the `--ad-no-huggingface` command line flag (False when the flag is not defined).
+no_huggingface = getattr(cmd_opts, "ad_no_huggingface", False)
+adetailer_dir = Path(paths.models_path, "adetailer")
+extra_models_dir = shared.opts.data.get("ad_extra_models_dir", "")
+# Model lookup built once at import time; its keys are offered as model choices in `ui`.
+model_mapping = get_models(
+ adetailer_dir, extra_dir=extra_models_dir, huggingface=not no_huggingface
+)
+# Start as None and are handed to WebuiInfo in `ui`; NOTE(review): presumably
+# assigned elsewhere once the generate buttons exist — confirm.
+txt2img_submit_button = img2img_submit_button = None
+# Default value of the `ad_script_names` option (comma-separated script file stems).
+SCRIPT_DEFAULT = "dynamic_prompting,dynamic_thresholding,wildcard_recursive,wildcards,lora_block_weight,negpip"
+
+# Create the adetailer model directory only when its parent exists and is writable.
+if (
+ not adetailer_dir.exists()
+ and adetailer_dir.parent.exists()
+ and os.access(adetailer_dir.parent, os.W_OK)
+):
+ adetailer_dir.mkdir()
+
+print(
+ f"[-] ADetailer initialized. version: {__version__}, num models: {len(model_mapping)}"
+)
+
+
+@contextmanager
+def change_torch_load():
+ orig = torch.load
+ try:
+ torch.load = safe.unsafe_torch_load
+ yield
+ finally:
+ torch.load = orig
+
+
+@contextmanager
+def pause_total_tqdm():
+ orig = opts.data.get("multiple_tqdm", True)
+ try:
+ opts.data["multiple_tqdm"] = False
+ yield
+ finally:
+ opts.data["multiple_tqdm"] = orig
+
+
+@contextmanager
+def preseve_prompts(p):
+ all_pt = copy(p.all_prompts)
+ all_ng = copy(p.all_negative_prompts)
+ try:
+ yield
+ finally:
+ p.all_prompts = all_pt
+ p.all_negative_prompts = all_ng
+
+
+class AfterDetailerScript(scripts.Script):
+ def __init__(self):
+ super().__init__()
+ self.ultralytics_device = self.get_ultralytics_device()
+
+ self.controlnet_ext = None
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(version={__version__})"
+
+ def title(self):
+ return AFTER_DETAILER
+
+ def show(self, is_img2img):
+ return scripts.AlwaysVisible
+
+ def ui(self, is_img2img):
+ num_models = opts.data.get("ad_max_models", 2)
+ ad_model_list = list(model_mapping.keys())
+ sampler_names = [sampler.name for sampler in all_samplers]
+
+ try:
+ checkpoint_list = modules.sd_models.checkpoint_tiles(use_shorts=True)
+ except TypeError:
+ checkpoint_list = modules.sd_models.checkpoint_tiles()
+ vae_list = modules.shared_items.sd_vae_items()
+
+ webui_info = WebuiInfo(
+ ad_model_list=ad_model_list,
+ sampler_names=sampler_names,
+ t2i_button=txt2img_submit_button,
+ i2i_button=img2img_submit_button,
+ checkpoints_list=checkpoint_list,
+ vae_list=vae_list,
+ )
+
+ components, infotext_fields = adui(num_models, is_img2img, webui_info)
+
+ self.infotext_fields = infotext_fields
+ return components
+
+ def init_controlnet_ext(self) -> None:
+ if self.controlnet_ext is not None:
+ return
+ self.controlnet_ext = ControlNetExt()
+
+ if controlnet_exists:
+ try:
+ self.controlnet_ext.init_controlnet()
+ except ImportError:
+ error = traceback.format_exc()
+ print(
+ f"[-] ADetailer: ControlNetExt init failed:\n{error}",
+ file=sys.stderr,
+ )
+
+ def update_controlnet_args(self, p, args: ADetailerArgs) -> None:
+ if self.controlnet_ext is None:
+ self.init_controlnet_ext()
+
+ if (
+ self.controlnet_ext is not None
+ and self.controlnet_ext.cn_available
+ and args.ad_controlnet_model != "None"
+ ):
+ self.controlnet_ext.update_scripts_args(
+ p,
+ model=args.ad_controlnet_model,
+ module=args.ad_controlnet_module,
+ weight=args.ad_controlnet_weight,
+ guidance_start=args.ad_controlnet_guidance_start,
+ guidance_end=args.ad_controlnet_guidance_end,
+ )
+
+ def is_ad_enabled(self, *args_) -> bool:
+ arg_list = [arg for arg in args_ if isinstance(arg, dict)]
+ if not args_ or not arg_list:
+ message = f"""
+ [-] ADetailer: Invalid arguments passed to ADetailer.
+ input: {args_!r}
+ ADetailer disabled.
+ """
+ print(dedent(message), file=sys.stderr)
+ return False
+
+ ad_enabled = args_[0] if isinstance(args_[0], bool) else True
+ not_none = any(arg.get("ad_model", "None") != "None" for arg in arg_list)
+ return ad_enabled and not_none
+
+ def check_skip_img2img(self, p, *args_) -> None:
+ if (
+ hasattr(p, "_ad_skip_img2img")
+ or not hasattr(p, "init_images")
+ or not p.init_images
+ ):
+ return
+
+ if len(args_) >= 2 and isinstance(args_[1], bool):
+ p._ad_skip_img2img = args_[1]
+ if args_[1]:
+ p._ad_orig = SkipImg2ImgOrig(
+ steps=p.steps,
+ sampler_name=p.sampler_name,
+ width=p.width,
+ height=p.height,
+ )
+ p.steps = 1
+ p.sampler_name = "Euler"
+ p.width = 128
+ p.height = 128
+ else:
+ p._ad_skip_img2img = False
+
+ @staticmethod
+ def get_i(p) -> int:
+ it = p.iteration
+ bs = p.batch_size
+ i = p.batch_index
+ return it * bs + i
+
+ def get_args(self, p, *args_) -> list[ADetailerArgs]:
+ """
+ `args_` is at least 1 in length by `is_ad_enabled` immediately above
+ """
+ args = [arg for arg in args_ if isinstance(arg, dict)]
+
+ if not args:
+ message = f"[-] ADetailer: Invalid arguments passed to ADetailer: {args_!r}"
+ raise ValueError(message)
+
+ if hasattr(p, "_ad_xyz"):
+ args[0] = {**args[0], **p._ad_xyz}
+
+ all_inputs = []
+
+ for n, arg_dict in enumerate(args, 1):
+ try:
+ inp = ADetailerArgs(**arg_dict)
+ except ValueError as e:
+ msgs = [
+ f"[-] ADetailer: ValidationError when validating {ordinal(n)} arguments: {e}\n"
+ ]
+ for attr in ALL_ARGS.attrs:
+ arg = arg_dict.get(attr)
+ dtype = type(arg)
+ arg = "DEFAULT" if arg is None else repr(arg)
+ msgs.append(f" {attr}: {arg} ({dtype})")
+ raise ValueError("\n".join(msgs)) from e
+
+ all_inputs.append(inp)
+
+ return all_inputs
+
+ def extra_params(self, arg_list: list[ADetailerArgs]) -> dict:
+ params = {}
+ for n, args in enumerate(arg_list):
+ params.update(args.extra_params(suffix=suffix(n)))
+ params["ADetailer version"] = __version__
+ return params
+
+ @staticmethod
+ def get_ultralytics_device() -> str:
+ if "adetailer" in shared.cmd_opts.use_cpu:
+ return "cpu"
+
+ if platform.system() == "Darwin":
+ return ""
+
+ vram_args = ["lowvram", "medvram", "medvram_sdxl"]
+ if any(getattr(cmd_opts, vram, False) for vram in vram_args):
+ return "cpu"
+
+ return ""
+
+ def prompt_blank_replacement(
+ self, all_prompts: list[str], i: int, default: str
+ ) -> str:
+ if not all_prompts:
+ return default
+ if i < len(all_prompts):
+ return all_prompts[i]
+ j = i % len(all_prompts)
+ return all_prompts[j]
+
+ def _get_prompt(
+ self,
+ ad_prompt: str,
+ all_prompts: list[str],
+ i: int,
+ default: str,
+ replacements: list[PromptSR],
+ ) -> list[str]:
+ prompts = re.split(r"\s*\[SEP\]\s*", ad_prompt)
+ blank_replacement = self.prompt_blank_replacement(all_prompts, i, default)
+ for n in range(len(prompts)):
+ if not prompts[n]:
+ prompts[n] = blank_replacement
+ elif "[PROMPT]" in prompts[n]:
+ prompts[n] = prompts[n].replace("[PROMPT]", f" {blank_replacement} ")
+
+ for pair in replacements:
+ prompts[n] = prompts[n].replace(pair.s, pair.r)
+ return prompts
+
+ def get_prompt(self, p, args: ADetailerArgs) -> tuple[list[str], list[str]]:
+ i = self.get_i(p)
+ prompt_sr = p._ad_xyz_prompt_sr if hasattr(p, "_ad_xyz_prompt_sr") else []
+
+ prompt = self._get_prompt(args.ad_prompt, p.all_prompts, i, p.prompt, prompt_sr)
+ negative_prompt = self._get_prompt(
+ args.ad_negative_prompt,
+ p.all_negative_prompts,
+ i,
+ p.negative_prompt,
+ prompt_sr,
+ )
+
+ return prompt, negative_prompt
+
+ def get_seed(self, p) -> tuple[int, int]:
+ i = self.get_i(p)
+
+ if not p.all_seeds:
+ seed = p.seed
+ elif i < len(p.all_seeds):
+ seed = p.all_seeds[i]
+ else:
+ j = i % len(p.all_seeds)
+ seed = p.all_seeds[j]
+
+ if not p.all_subseeds:
+ subseed = p.subseed
+ elif i < len(p.all_subseeds):
+ subseed = p.all_subseeds[i]
+ else:
+ j = i % len(p.all_subseeds)
+ subseed = p.all_subseeds[j]
+
+ return seed, subseed
+
+ def get_width_height(self, p, args: ADetailerArgs) -> tuple[int, int]:
+ if args.ad_use_inpaint_width_height:
+ width = args.ad_inpaint_width
+ height = args.ad_inpaint_height
+ elif hasattr(p, "_ad_orig"):
+ width = p._ad_orig.width
+ height = p._ad_orig.height
+ else:
+ width = p.width
+ height = p.height
+
+ return width, height
+
+ def get_steps(self, p, args: ADetailerArgs) -> int:
+ if args.ad_use_steps:
+ return args.ad_steps
+ if hasattr(p, "_ad_orig"):
+ return p._ad_orig.steps
+ return p.steps
+
+ def get_cfg_scale(self, p, args: ADetailerArgs) -> float:
+ return args.ad_cfg_scale if args.ad_use_cfg_scale else p.cfg_scale
+
+ def get_sampler(self, p, args: ADetailerArgs) -> str:
+ if args.ad_use_sampler:
+ return args.ad_sampler
+ if hasattr(p, "_ad_orig"):
+ return p._ad_orig.sampler_name
+ return p.sampler_name
+
+ def get_override_settings(self, p, args: ADetailerArgs) -> dict[str, Any]:
+ d = {}
+
+ if args.ad_use_clip_skip:
+ d["CLIP_stop_at_last_layers"] = args.ad_clip_skip
+
+ if (
+ args.ad_use_checkpoint
+ and args.ad_checkpoint
+ and args.ad_checkpoint not in ("None", "Use same checkpoint")
+ ):
+ d["sd_model_checkpoint"] = args.ad_checkpoint
+
+ if (
+ args.ad_use_vae
+ and args.ad_vae
+ and args.ad_vae not in ("None", "Use same VAE")
+ ):
+ d["sd_vae"] = args.ad_vae
+ return d
+
+ def get_initial_noise_multiplier(self, p, args: ADetailerArgs) -> float | None:
+ return args.ad_noise_multiplier if args.ad_use_noise_multiplier else None
+
+ @staticmethod
+ def infotext(p) -> str:
+ return create_infotext(
+ p, p.all_prompts, p.all_seeds, p.all_subseeds, None, 0, 0
+ )
+
+ def write_params_txt(self, content: str) -> None:
+ params_txt = Path(paths.data_path, "params.txt")
+ with suppress(Exception):
+ params_txt.write_text(content, encoding="utf-8")
+
+ @staticmethod
+ def script_args_copy(script_args):
+ type_: type[list] | type[tuple] = type(script_args)
+ result = []
+ for arg in script_args:
+ try:
+ a = copy(arg)
+ except TypeError:
+ a = arg
+ result.append(a)
+ return type_(result)
+
+ def script_filter(self, p, args: ADetailerArgs):
+ script_runner = copy(p.scripts)
+ script_args = self.script_args_copy(p.script_args)
+
+ ad_only_seleted_scripts = opts.data.get("ad_only_seleted_scripts", True)
+ if not ad_only_seleted_scripts:
+ return script_runner, script_args
+
+ ad_script_names = opts.data.get("ad_script_names", SCRIPT_DEFAULT)
+ script_names_set = {
+ name
+ for script_name in ad_script_names.split(",")
+ for name in (script_name, script_name.strip())
+ }
+
+ if args.ad_controlnet_model != "None":
+ script_names_set.add("controlnet")
+
+ filtered_alwayson = []
+ for script_object in script_runner.alwayson_scripts:
+ filepath = script_object.filename
+ filename = Path(filepath).stem
+ if filename in script_names_set:
+ filtered_alwayson.append(script_object)
+
+ script_runner.alwayson_scripts = filtered_alwayson
+ return script_runner, script_args
+
+ def disable_controlnet_units(
+ self, script_args: list[Any] | tuple[Any, ...]
+ ) -> None:
+ for obj in script_args:
+ if "controlnet" in obj.__class__.__name__.lower():
+ if hasattr(obj, "enabled"):
+ obj.enabled = False
+ if hasattr(obj, "input_mode"):
+ obj.input_mode = getattr(obj.input_mode, "SIMPLE", "simple")
+
+ elif isinstance(obj, dict) and "module" in obj:
+ obj["enabled"] = False
+
+    def get_i2i_p(self, p, args: ADetailerArgs, image):
+        seed, subseed = self.get_seed(p)
+        width, height = self.get_width_height(p, args)
+        steps = self.get_steps(p, args)
+        cfg_scale = self.get_cfg_scale(p, args)
+        initial_noise_multiplier = self.get_initial_noise_multiplier(p, args)
+        sampler_name = self.get_sampler(p, args)
+        override_settings = self.get_override_settings(p, args)
+        # One-image, one-batch img2img job derived from p; mask and prompts start empty.
+        i2i = StableDiffusionProcessingImg2Img(
+            init_images=[image],
+            resize_mode=0,
+            denoising_strength=args.ad_denoising_strength,
+            mask=None,
+            mask_blur=args.ad_mask_blur,
+            inpainting_fill=1,
+            inpaint_full_res=args.ad_inpaint_only_masked,
+            inpaint_full_res_padding=args.ad_inpaint_only_masked_padding,
+            inpainting_mask_invert=0,
+            initial_noise_multiplier=initial_noise_multiplier,
+            sd_model=p.sd_model,
+            outpath_samples=p.outpath_samples,
+            outpath_grids=p.outpath_grids,
+            prompt="",  # replace later
+            negative_prompt="",
+            styles=p.styles,
+            seed=seed,
+            subseed=subseed,
+            subseed_strength=p.subseed_strength,
+            seed_resize_from_h=p.seed_resize_from_h,
+            seed_resize_from_w=p.seed_resize_from_w,
+            sampler_name=sampler_name,
+            batch_size=1,
+            n_iter=1,
+            steps=steps,
+            cfg_scale=cfg_scale,
+            width=width,
+            height=height,
+            restore_faces=args.ad_restore_face,
+            tiling=p.tiling,
+            extra_generation_params=p.extra_generation_params,
+            do_not_save_samples=True,
+            do_not_save_grid=True,
+            override_settings=override_settings,
+        )
+        # _ad_disabled makes process() / postprocess_image() return early for this inner job.
+        i2i.cached_c = [None, None]
+        i2i.cached_uc = [None, None]
+        i2i.scripts, i2i.script_args = self.script_filter(p, args)
+        i2i._ad_disabled = True
+        i2i._ad_inner = True
+        # Anything but "Passthrough": switch off the ControlNet units copied from p.
+        if args.ad_controlnet_model != "Passthrough":
+            self.disable_controlnet_units(i2i.script_args)
+
+        if args.ad_controlnet_model not in ["None", "Passthrough"]:
+            self.update_controlnet_args(i2i, args)
+        elif args.ad_controlnet_model == "None":
+            i2i.control_net_enabled = False
+
+        return i2i
+
+ def save_image(self, p, image, *, condition: str, suffix: str) -> None:
+ i = self.get_i(p)
+ if p.all_prompts:
+ i %= len(p.all_prompts)
+ save_prompt = p.all_prompts[i]
+ else:
+ save_prompt = p.prompt
+ seed, _ = self.get_seed(p)
+
+ if opts.data.get(condition, False):
+ images.save_image(
+ image=image,
+ path=p.outpath_samples,
+ basename="",
+ seed=seed,
+ prompt=save_prompt,
+ extension=opts.samples_format,
+ info=self.infotext(p),
+ p=p,
+ suffix=suffix,
+ )
+
+ def get_ad_model(self, name: str):
+ if name not in model_mapping:
+ msg = f"[-] ADetailer: Model {name!r} not found. Available models: {list(model_mapping.keys())}"
+ raise ValueError(msg)
+ return model_mapping[name]
+
+ def sort_bboxes(self, pred: PredictOutput) -> PredictOutput:
+ sortby = opts.data.get("ad_bbox_sortby", BBOX_SORTBY[0])
+ sortby_idx = BBOX_SORTBY.index(sortby)
+ return sort_bboxes(pred, sortby_idx)
+
+ def pred_preprocessing(self, pred: PredictOutput, args: ADetailerArgs):
+ pred = filter_by_ratio(
+ pred, low=args.ad_mask_min_ratio, high=args.ad_mask_max_ratio
+ )
+ pred = filter_k_largest(pred, k=args.ad_mask_k_largest)
+ pred = self.sort_bboxes(pred)
+ return mask_preprocess(
+ pred.masks,
+ kernel=args.ad_dilate_erode,
+ x_offset=args.ad_x_offset,
+ y_offset=args.ad_y_offset,
+ merge_invert=args.ad_mask_merge_invert,
+ )
+
+ @staticmethod
+ def ensure_rgb_image(image: Any):
+ if not isinstance(image, Image.Image):
+ image = to_pil_image(image)
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+ return image
+
+ @staticmethod
+ def i2i_prompts_replace(
+ i2i, prompts: list[str], negative_prompts: list[str], j: int
+ ) -> None:
+ i1 = min(j, len(prompts) - 1)
+ i2 = min(j, len(negative_prompts) - 1)
+ prompt = prompts[i1]
+ negative_prompt = negative_prompts[i2]
+ i2i.prompt = prompt
+ i2i.negative_prompt = negative_prompt
+
+ @staticmethod
+ def compare_prompt(p, processed, n: int = 0):
+ if p.prompt != processed.all_prompts[0]:
+ print(
+ f"[-] ADetailer: applied {ordinal(n + 1)} ad_prompt: {processed.all_prompts[0]!r}"
+ )
+
+ if p.negative_prompt != processed.all_negative_prompts[0]:
+ print(
+ f"[-] ADetailer: applied {ordinal(n + 1)} ad_negative_prompt: {processed.all_negative_prompts[0]!r}"
+ )
+
+ @staticmethod
+ def need_call_process(p) -> bool:
+ if p.scripts is None:
+ return False
+ i = p.batch_index
+ bs = p.batch_size
+ return i == bs - 1
+
+ @staticmethod
+ def need_call_postprocess(p) -> bool:
+ if p.scripts is None:
+ return False
+ return p.batch_index == 0
+
+ @staticmethod
+ def get_i2i_init_image(p, pp):
+ if getattr(p, "_ad_skip_img2img", False):
+ return p.init_images[0]
+ return pp.image
+
+ @staticmethod
+ def get_each_tap_seed(seed: int, i: int):
+ use_same_seed = shared.opts.data.get("ad_same_seed_for_each_tap", False)
+ return seed if use_same_seed else seed + i
+
+ @staticmethod
+ def is_img2img_inpaint(p) -> bool:
+ return hasattr(p, "image_mask") and bool(p.image_mask)
+
+ @rich_traceback
+ def process(self, p, *args_):
+ if getattr(p, "_ad_disabled", False):
+ return
+
+ if self.is_img2img_inpaint(p):
+ p._ad_disabled = True
+ msg = "[-] ADetailer: img2img inpainting detected. adetailer disabled."
+ print(msg)
+ return
+
+ if self.is_ad_enabled(*args_):
+ arg_list = self.get_args(p, *args_)
+ self.check_skip_img2img(p, *args_)
+ extra_params = self.extra_params(arg_list)
+ p.extra_generation_params.update(extra_params)
+ else:
+ p._ad_disabled = True
+
+    def _postprocess_image_inner(
+        self, p, pp, args: ADetailerArgs, *, n: int = 0
+    ) -> bool:
+        """Detect with `args.ad_model` on `pp.image` and inpaint each mask in turn.
+
+        Returns
+        -------
+        bool
+            `True` if image was processed, `False` otherwise.
+        """
+        if state.interrupted or state.skipped:
+            return False
+
+        i = self.get_i(p)
+
+        i2i = self.get_i2i_p(p, args, pp.image)
+        seed, subseed = self.get_seed(p)
+        ad_prompts, ad_negatives = self.get_prompt(p, args)
+
+        is_mediapipe = args.ad_model.lower().startswith("mediapipe")
+        # Only the ultralytics predictor gets a resolved model and a device argument.
+        kwargs = {}
+        if is_mediapipe:
+            predictor = mediapipe_predict
+            ad_model = args.ad_model
+        else:
+            predictor = ultralytics_predict
+            ad_model = self.get_ad_model(args.ad_model)
+            kwargs["device"] = self.ultralytics_device
+
+        with change_torch_load():
+            pred = predictor(ad_model, pp.image, args.ad_confidence, **kwargs)
+
+        masks = self.pred_preprocessing(pred, args)
+        shared.state.assign_current_image(pred.preview)
+
+        if not masks:
+            print(
+                f"[-] ADetailer: nothing detected on image {i + 1} with {ordinal(n + 1)} settings."
+            )
+            return False
+
+        self.save_image(
+            p,
+            pred.preview,
+            condition="ad_save_previews",
+            suffix="-ad-preview" + suffix(n, "-"),
+        )
+        # Account for one additional job per detected mask.
+        steps = len(masks)
+        processed = None
+        state.job_count += steps
+
+        if is_mediapipe:
+            print(f"mediapipe: {steps} detected.")
+
+        p2 = copy(i2i)
+        for j in range(steps):
+            p2.image_mask = masks[j]
+            p2.init_images[0] = self.ensure_rgb_image(p2.init_images[0])
+            self.i2i_prompts_replace(p2, ad_prompts, ad_negatives, j)
+            # A prompt that is exactly "[SKIP]" (ignoring whitespace) leaves this mask alone.
+            if re.match(r"^\s*\[SKIP\]\s*$", p2.prompt):
+                continue
+
+            p2.seed = self.get_each_tap_seed(seed, j)
+            p2.subseed = self.get_each_tap_seed(subseed, j)
+
+            try:
+                processed = process_images(p2)
+            except NansException as e:
+                msg = f"[-] ADetailer: 'NansException' occurred with {ordinal(n + 1)} settings.\n{e}"
+                print(msg, file=sys.stderr)
+                continue
+            finally:
+                p2.close()
+            # Start the next pass from the template again, but on top of this pass's output.
+            self.compare_prompt(p2, processed, n=n)
+            p2 = copy(i2i)
+            p2.init_images = [processed.images[0]]
+        # processed is still None if every mask was skipped or hit a NansException.
+        if processed is not None:
+            pp.image = processed.images[0]
+            return True
+
+        return False
+
+ @rich_traceback
+ def postprocess_image(self, p, pp, *args_):
+ if getattr(p, "_ad_disabled", False) or not self.is_ad_enabled(*args_):
+ return
+
+ pp.image = self.get_i2i_init_image(p, pp)
+ pp.image = self.ensure_rgb_image(pp.image)
+ init_image = copy(pp.image)
+ arg_list = self.get_args(p, *args_)
+ params_txt_content = Path(paths.data_path, "params.txt").read_text("utf-8")
+
+ if self.need_call_postprocess(p):
+ dummy = Processed(p, [], p.seed, "")
+ with preseve_prompts(p):
+ p.scripts.postprocess(copy(p), dummy)
+
+ is_processed = False
+ with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
+ for n, args in enumerate(arg_list):
+ if args.ad_model == "None":
+ continue
+ is_processed |= self._postprocess_image_inner(p, pp, args, n=n)
+
+ if is_processed and not getattr(p, "_ad_skip_img2img", False):
+ self.save_image(
+ p, init_image, condition="ad_save_images_before", suffix="-ad-before"
+ )
+
+ if self.need_call_process(p):
+ with preseve_prompts(p):
+ copy_p = copy(p)
+ if hasattr(p.scripts, "before_process"):
+ p.scripts.before_process(copy_p)
+ p.scripts.process(copy_p)
+
+ self.write_params_txt(params_txt_content)
+
+
+def on_after_component(component, **_kwargs):
+ global txt2img_submit_button, img2img_submit_button
+ if getattr(component, "elem_id", None) == "txt2img_generate":
+ txt2img_submit_button = component
+ return
+
+ if getattr(component, "elem_id", None) == "img2img_generate":
+ img2img_submit_button = component
+
+
+def on_ui_settings():
+    """Register every ADetailer option with shared.opts, all in one settings section."""
+    section = ("ADetailer", AFTER_DETAILER)
+    shared.opts.add_option(
+        "ad_max_models",
+        shared.OptionInfo(
+            default=2,
+            label="Max models",
+            component=gr.Slider,
+            component_args={"minimum": 1, "maximum": 10, "step": 1},
+            section=section,
+        ),
+    )
+
+    shared.opts.add_option(
+        "ad_extra_models_dir",
+        shared.OptionInfo(
+            default="",
+            label="Extra path to scan adetailer models",
+            component=gr.Textbox,
+            section=section,
+        ),
+    )
+
+    shared.opts.add_option(
+        "ad_save_previews",
+        shared.OptionInfo(False, "Save mask previews", section=section),
+    )
+    shared.opts.add_option(
+        "ad_save_images_before",
+        shared.OptionInfo(False, "Save images before ADetailer", section=section),
+    )
+    # The key's "seleted" spelling is what script_filter() looks up; keep the two in sync.
+    shared.opts.add_option(
+        "ad_only_seleted_scripts",
+        shared.OptionInfo(
+            True, "Apply only selected scripts to ADetailer", section=section
+        ),
+    )
+
+    textbox_args = {
+        "placeholder": "comma-separated list of script names",
+        "interactive": True,
+    }
+
+    shared.opts.add_option(
+        "ad_script_names",
+        shared.OptionInfo(
+            default=SCRIPT_DEFAULT,
+            label="Script names to apply to ADetailer (separated by comma)",
+            component=gr.Textbox,
+            component_args=textbox_args,
+            section=section,
+        ),
+    )
+
+    shared.opts.add_option(
+        "ad_bbox_sortby",
+        shared.OptionInfo(
+            default="None",
+            label="Sort bounding boxes by",
+            component=gr.Radio,
+            component_args={"choices": BBOX_SORTBY},
+            section=section,
+        ),
+    )
+    # The key says "tap" while the label says "tab"; get_each_tap_seed() reads this key.
+    shared.opts.add_option(
+        "ad_same_seed_for_each_tap",
+        shared.OptionInfo(
+            False, "Use same seed for each tab in adetailer", section=section
+        ),
+    )
+
+
+# xyz_grid
+
+
+class PromptSR(NamedTuple):
+    s: str  # text to search for
+    r: str  # text that replaces it
+
+
+def set_value(p, x: Any, xs: Any, *, field: str):
+ if not hasattr(p, "_ad_xyz"):
+ p._ad_xyz = {}
+ p._ad_xyz[field] = x
+
+
+def search_and_replace_prompt(p, x: Any, xs: Any, replace_in_main_prompt: bool):
+ if replace_in_main_prompt:
+ p.prompt = p.prompt.replace(xs[0], x)
+ p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+
+ if not hasattr(p, "_ad_xyz_prompt_sr"):
+ p._ad_xyz_prompt_sr = []
+ p._ad_xyz_prompt_sr.append(PromptSR(s=xs[0], r=x))
+
+
+def make_axis_on_xyz_grid():
+    """Append the "[ADetailer] ..." axis options to the xyz_grid script, if it is loaded."""
+    xyz_grid = None
+    for script in scripts.scripts_data:
+        if script.script_class.__module__ == "xyz_grid.py":
+            xyz_grid = script.module
+            break
+    # No script module named xyz_grid.py was found: there is nothing to extend.
+    if xyz_grid is None:
+        return
+
+    model_list = ["None", *model_mapping.keys()]
+    samplers = [sampler.name for sampler in all_samplers]
+    axis = [
+        xyz_grid.AxisOption(
+            "[ADetailer] ADetailer model 1st",
+            str,
+            partial(set_value, field="ad_model"),
+            choices=lambda: model_list,
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] ADetailer prompt 1st",
+            str,
+            partial(set_value, field="ad_prompt"),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] ADetailer negative prompt 1st",
+            str,
+            partial(set_value, field="ad_negative_prompt"),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] Prompt S/R (AD 1st)",
+            str,
+            partial(search_and_replace_prompt, replace_in_main_prompt=False),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] Prompt S/R (AD 1st and main prompt)",
+            str,
+            partial(search_and_replace_prompt, replace_in_main_prompt=True),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] Mask erosion / dilation 1st",
+            int,
+            partial(set_value, field="ad_dilate_erode"),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] Inpaint denoising strength 1st",
+            float,
+            partial(set_value, field="ad_denoising_strength"),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] Inpaint only masked 1st",
+            str,
+            partial(set_value, field="ad_inpaint_only_masked"),
+            choices=lambda: ["True", "False"],
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] Inpaint only masked padding 1st",
+            int,
+            partial(set_value, field="ad_inpaint_only_masked_padding"),
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] ADetailer sampler 1st",
+            str,
+            partial(set_value, field="ad_sampler"),
+            choices=lambda: samplers,
+        ),
+        xyz_grid.AxisOption(
+            "[ADetailer] ControlNet model 1st",
+            str,
+            partial(set_value, field="ad_controlnet_model"),
+            choices=lambda: ["None", *get_cn_models()],
+        ),
+    ]
+    # Extend only once: skip when "[ADetailer]" axes are already registered.
+    if not any(x.label.startswith("[ADetailer]") for x in xyz_grid.axis_options):
+        xyz_grid.axis_options.extend(axis)
+
+
+def on_before_ui():
+ try:
+ make_axis_on_xyz_grid()
+ except Exception:
+ error = traceback.format_exc()
+ print(
+ f"[-] ADetailer: xyz_grid error:\n{error}",
+ file=sys.stderr,
+ )
+
+
+# api
+
+
+def add_api_endpoints(_: gr.Blocks, app: FastAPI):  # adds three read-only GET routes to `app`
+    @app.get("/adetailer/v1/version")
+    def version():
+        return {"version": __version__}
+
+    @app.get("/adetailer/v1/schema")
+    def schema():
+        return ADetailerArgs.schema()
+
+    @app.get("/adetailer/v1/ad_model")
+    def ad_model():
+        return {"ad_model": list(model_mapping)}  # the keys of model_mapping
+
+
+script_callbacks.on_ui_settings(on_ui_settings)  # register the module-level hooks defined above
+script_callbacks.on_after_component(on_after_component)
+script_callbacks.on_app_started(add_api_endpoints)
+script_callbacks.on_before_ui(on_before_ui)
diff --git a/kohya-sd-scripts-webui/.gitignore b/kohya-sd-scripts-webui/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..75f0ff1a77ceebd77b6f8cda2434327562ada2c5
--- /dev/null
+++ b/kohya-sd-scripts-webui/.gitignore
@@ -0,0 +1,9 @@
+__pycache__
+venv
+tmp
+
+kohya_ss
+wd14_tagger_model
+presets.json
+meta.json
+presets
\ No newline at end of file
diff --git a/kohya-sd-scripts-webui/README.md b/kohya-sd-scripts-webui/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1046cf714df992f604d9bc679e3e820ac3f85c9c
--- /dev/null
+++ b/kohya-sd-scripts-webui/README.md
@@ -0,0 +1,22 @@
+# kohya sd-scripts webui
+
+[![](https://img.shields.io/static/v1?message=Open%20in%20Colab&logo=googlecolab&labelColor=5c5c5c&color=0f80c1&label=%20&style=for-the-badge)](https://colab.research.google.com/github/ddPn08/kohya-sd-scripts-webui/blob/main/kohya-sd-scripts-webui-colab.ipynb)
+
+Gradio wrapper for [sd-scripts](https://github.com/kohya-ss/sd-scripts) by kohya
+
+It can be used as an extension of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) or can be launched standalone.
+
+![](/screenshots/webui-01.png)
+
+# Usage
+## As an extension of stable-diffusion-webui
+
+Go to `Extensions` > `Install from URL`, enter the following URL and press the install button.
+
+https://github.com/ddpn08/kohya-sd-scripts-webui.git
+
+![](/screenshots/installation-extension.png)
+
+## Start standalone
+
+Run `webui.bat` on Windows, or `webui.sh` on Linux and macOS.
diff --git a/kohya-sd-scripts-webui/built-in-presets.json b/kohya-sd-scripts-webui/built-in-presets.json
new file mode 100644
index 0000000000000000000000000000000000000000..be2d461745db6f251a85e817cc2fa2faca91f888
--- /dev/null
+++ b/kohya-sd-scripts-webui/built-in-presets.json
@@ -0,0 +1,126 @@
+{
+ "train_network": {
+ "lora-x512": {
+ "v2": null,
+ "v_parameterization": null,
+ "pretrained_model_name_or_path": null,
+ "train_data_dir": null,
+ "shuffle_caption": true,
+ "caption_extension": ".caption",
+ "caption_extention": null,
+ "keep_tokens": null,
+ "color_aug": null,
+ "flip_aug": true,
+ "face_crop_aug_range": null,
+ "random_crop": null,
+ "debug_dataset": null,
+ "resolution": "512",
+ "cache_latents": null,
+ "enable_bucket": true,
+ "min_bucket_reso": 256,
+ "max_bucket_reso": 1024,
+ "reg_data_dir": null,
+ "in_json": null,
+ "dataset_repeats": 1,
+ "output_dir": null,
+ "output_name": null,
+ "save_precision": null,
+ "save_every_n_epochs": 5,
+ "save_n_epoch_ratio": null,
+ "save_last_n_epochs": null,
+ "save_last_n_epochs_state": null,
+ "save_state": null,
+ "resume": null,
+ "train_batch_size": 1,
+ "max_token_length": null,
+ "use_8bit_adam": true,
+ "mem_eff_attn": null,
+ "xformers": true,
+ "vae": null,
+ "learning_rate": 0.0001,
+ "max_train_steps": 1600,
+ "max_train_epochs": null,
+ "max_data_loader_n_workers": 8,
+ "seed": null,
+ "gradient_checkpointing": true,
+ "gradient_accumulation_steps": 1,
+ "mixed_precision": "no",
+ "full_fp16": null,
+ "clip_skip": 2,
+ "logging_dir": null,
+ "log_prefix": null,
+ "lr_scheduler": "constant",
+ "lr_warmup_steps": 0,
+ "prior_loss_weight": 1.0,
+ "no_metadata": null,
+ "save_model_as": "safetensors",
+ "unet_lr": null,
+ "text_encoder_lr": null,
+ "network_weights": null,
+ "network_module": "networks.lora",
+ "network_dim": 16,
+ "network_alpha": 1.0,
+ "network_args": null,
+ "network_train_unet_only": null,
+ "network_train_text_encoder_only": null,
+ "training_comment": null
+ }
+ },
+ "train_db": {
+ "db-x512": {
+ "v2": null,
+ "v_parameterization": null,
+ "pretrained_model_name_or_path": null,
+ "train_data_dir": null,
+ "shuffle_caption": true,
+ "caption_extension": ".caption",
+ "caption_extention": null,
+ "keep_tokens": null,
+ "color_aug": null,
+ "flip_aug": true,
+ "face_crop_aug_range": null,
+ "random_crop": null,
+ "debug_dataset": null,
+ "resolution": null,
+ "cache_latents": null,
+ "enable_bucket": true,
+ "min_bucket_reso": 256,
+ "max_bucket_reso": 1024,
+ "reg_data_dir": null,
+ "output_dir": null,
+ "output_name": null,
+ "save_precision": null,
+ "save_every_n_epochs": 5,
+ "save_n_epoch_ratio": null,
+ "save_last_n_epochs": null,
+ "save_last_n_epochs_state": null,
+ "save_state": null,
+ "resume": null,
+ "train_batch_size": 1,
+ "max_token_length": null,
+ "use_8bit_adam": true,
+ "mem_eff_attn": null,
+ "xformers": true,
+ "vae": null,
+ "learning_rate": 1e-06,
+ "max_train_steps": 1600,
+ "max_train_epochs": null,
+ "max_data_loader_n_workers": 8,
+ "seed": null,
+ "gradient_checkpointing": null,
+ "gradient_accumulation_steps": 1,
+ "mixed_precision": "no",
+ "full_fp16": null,
+ "clip_skip": 2,
+ "logging_dir": null,
+ "log_prefix": null,
+ "lr_scheduler": "constant",
+ "lr_warmup_steps": 0,
+ "prior_loss_weight": 1.0,
+ "save_model_as": "safetensors",
+ "use_safetensors": null,
+ "no_token_padding": null,
+ "stop_text_encoder_training": null
+ }
+ }
+}
diff --git a/kohya-sd-scripts-webui/install.py b/kohya-sd-scripts-webui/install.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ec456d484630338687f26d709bf342f8aa14fa5
--- /dev/null
+++ b/kohya-sd-scripts-webui/install.py
@@ -0,0 +1,116 @@
+import sys
+import launch
+import platform
+import os
+import shutil
+import site
+import glob
+import re
+
+dirname = os.path.dirname(__file__)
+repo_dir = os.path.join(dirname, "kohya_ss")
+
+
+def prepare_environment():
+ torch_command = os.environ.get(
+ "TORCH_COMMAND",
+ "pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 --extra-index-url https://download.pytorch.org/whl/cu118",
+ )
+ sd_scripts_repo = os.environ.get("SD_SCRIPTS_REPO", "https://github.com/kohya-ss/sd-scripts.git")
+ sd_scripts_branch = os.environ.get("SD_SCRIPTS_BRANCH", "main")
+ requirements_file = os.environ.get("REQS_FILE", "requirements.txt")
+
+ sys.argv, skip_install = launch.extract_arg(sys.argv, "--skip-install")
+ sys.argv, disable_strict_version = launch.extract_arg(
+ sys.argv, "--disable-strict-version"
+ )
+ sys.argv, skip_torch_cuda_test = launch.extract_arg(
+ sys.argv, "--skip-torch-cuda-test"
+ )
+ sys.argv, skip_checkout_repo = launch.extract_arg(sys.argv, "--skip-checkout-repo")
+ sys.argv, update = launch.extract_arg(sys.argv, "--update")
+ sys.argv, reinstall_xformers = launch.extract_arg(sys.argv, "--reinstall-xformers")
+ sys.argv, reinstall_torch = launch.extract_arg(sys.argv, "--reinstall-torch")
+ xformers = "--xformers" in sys.argv
+ ngrok = "--ngrok" in sys.argv
+
+ if skip_install:
+ return
+
+
+ if (
+ reinstall_torch
+ or not launch.is_installed("torch")
+ or not launch.is_installed("torchvision")
+ ) and not disable_strict_version:
+ launch.run(
+ f'"{launch.python}" -m {torch_command}',
+ "Installing torch and torchvision",
+ "Couldn't install torch",
+ )
+
+ if not skip_torch_cuda_test:
+ launch.run_python(
+ "import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'"
+ )
+
+ if (not launch.is_installed("xformers") or reinstall_xformers) and xformers:
+ launch.run_pip("install xformers --pre", "xformers")
+
+ if update and os.path.exists(repo_dir):
+ launch.run(f'cd "{repo_dir}" && {launch.git} fetch --prune')
+ launch.run(f'cd "{repo_dir}" && {launch.git} reset --hard origin/main')
+ elif not os.path.exists(repo_dir):
+ launch.run(
+ f'{launch.git} clone {sd_scripts_repo} "{repo_dir}"'
+ )
+
+ if not skip_checkout_repo:
+ launch.run(f'cd "{repo_dir}" && {launch.git} checkout {sd_scripts_branch}')
+
+ if not launch.is_installed("gradio"):
+ launch.run_pip("install gradio==3.16.2", "gradio")
+
+ if not launch.is_installed("pyngrok") and ngrok:
+ launch.run_pip("install pyngrok", "ngrok")
+
+ if platform.system() == "Linux":
+ if not launch.is_installed("triton"):
+ launch.run_pip("install triton", "triton")
+
+ if disable_strict_version:
+ with open(os.path.join(repo_dir, requirements_file), "r") as f:
+ txt = f.read()
+ requirements = [
+ re.split("==|<|>", a)[0]
+ for a in txt.split("\n")
+ if (not a.startswith("#") and a != ".")
+ ]
+ requirements = " ".join(requirements)
+ launch.run_pip(
+ f'install "{requirements}" "{repo_dir}"',
+ "requirements for kohya sd-scripts",
+ )
+ else:
+ launch.run(
+ f'cd "{repo_dir}" && "{launch.python}" -m pip install -r requirements.txt',
+ desc=f"Installing requirements for kohya sd-scripts",
+ errdesc=f"Couldn't install requirements for kohya sd-scripts",
+ )
+
+ if platform.system() == "Windows":
+ for file in glob.glob(os.path.join(repo_dir, "bitsandbytes_windows", "*")):
+ filename = os.path.basename(file)
+ for dir in site.getsitepackages():
+ outfile = (
+ os.path.join(dir, "bitsandbytes", "cuda_setup", filename)
+ if filename == "main.py"
+ else os.path.join(dir, "bitsandbytes", filename)
+ )
+ if not os.path.exists(os.path.dirname(outfile)):
+ continue
+ shutil.copy(file, outfile)
+
+
+if __name__ == "__main__":
+ prepare_environment()
diff --git a/kohya-sd-scripts-webui/kohya-sd-scripts-webui-colab.ipynb b/kohya-sd-scripts-webui/kohya-sd-scripts-webui-colab.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..d0a26aad494513a31534778ac1a11154a0e0afc1
--- /dev/null
+++ b/kohya-sd-scripts-webui/kohya-sd-scripts-webui-colab.ipynb
@@ -0,0 +1,157 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zSM6HuYmkYCt"
+ },
+ "source": [
+ "# [kohya sd-scripts webui](https://github.com/ddPn08/kohya-sd-scripts-webui)\n",
+ "\n",
+ "This notebook is for running [sd-scripts](https://github.com/kohya-ss/sd-scripts) by [Kohya](https://github.com/kohya-ss).\n",
+ "\n",
+ "このノートブックは[Kohya](https://github.com/kohya-ss)さんによる[sd-scripts](https://github.com/kohya-ss/sd-scripts)を実行するためのものです。\n",
+ "\n",
+ "# Repository\n",
+ "[kohya_ss/sd-scripts](https://github.com/kohya-ss/sd-scripts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zXcznGdeyb2I"
+ },
+ "outputs": [],
+ "source": [
+ "! nvidia-smi\n",
+ "! nvcc -V\n",
+ "! free -h"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "tj65Tb_oyxtP"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown # Mount Google Drive\n",
+ "mount_gdrive = True # @param {type:\"boolean\"}\n",
+ "gdrive_preset_path = \"/content/drive/MyDrive/AI/kohya-sd-scripts-webui/presets\" # @param {type:\"string\"}\n",
+ "\n",
+ "if mount_gdrive:\n",
+ " from google.colab import drive\n",
+ " drive.mount('/content/drive', force_remount=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "FN7UJvSdzBFF"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown # Initialize environment\n",
+ "\n",
+ "! git clone https://github.com/ddPn08/kohya-sd-scripts-webui.git\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "if not os.path.exists(gdrive_preset_path):\n",
+ " os.makedirs(gdrive_preset_path, exist_ok=True)\n",
+ "\n",
+ "! rm -f ./kohya-sd-scripts-webui/presets.json\n",
+ "! ln -s {gdrive_preset_path} ./kohya-sd-scripts-webui/presets\n",
+ "\n",
+ "conda_dir = \"/opt/conda\" # @param{type:\"string\"}\n",
+ "conda_bin = os.path.join(conda_dir, \"bin\", \"conda\")\n",
+ "\n",
+ "if not os.path.exists(conda_bin):\n",
+ " ! curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh\n",
+ " ! chmod +x Miniconda3-latest-Linux-x86_64.sh\n",
+ " ! bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p {conda_dir}\n",
+ " ! rm Miniconda3-latest-Linux-x86_64.sh\n",
+ "\n",
+ "def run_script(s):\n",
+ " ! {s}\n",
+ "\n",
+ "def make_args(d):\n",
+ " arguments = \"\"\n",
+ " for k, v in d.items():\n",
+ " if type(v) == bool:\n",
+ " arguments += f\"--{k} \" if v else \"\"\n",
+ " elif type(v) == str and v:\n",
+ " arguments += f\"--{k} \\\"{v}\\\" \"\n",
+ " elif v:\n",
+ " arguments += f\"--{k}={v} \"\n",
+ " return arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "uetu1lShs6aJ"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown # Run\n",
+ "\n",
+ "# @markdown \n",
+ "\n",
+ "# @markdown ## Optional | Ngrok Tunnel\n",
+ "# @markdown Get token from [here](https://dashboard.ngrok.com/get-started/your-authtoken)\n",
+ "\n",
+ "ngrok_token = \"\" # @param {type:\"string\"}\n",
+ "ngrok_region = \"us\" # @param [\"us\", \"eu\", \"au\", \"ap\", \"sa\", \"jp\", \"in\"]\n",
+ "\n",
+ "arguments = {\n",
+ " \"ngrok\": ngrok_token,\n",
+ " \"ngrok-region\": ngrok_region,\n",
+ " \"share\": not ngrok_token,\n",
+ " \"xformers\": True,\n",
+ " \"enable-console-log\": True\n",
+ "}\n",
+ "\n",
+ "run_script(f\"\"\"\n",
+ "eval \"$({conda_bin} shell.bash hook)\"\n",
+ "cd kohya-sd-scripts-webui\n",
+ "python launch.py {make_args(arguments)}\n",
+ "\"\"\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/kohya-sd-scripts-webui/launch.py b/kohya-sd-scripts-webui/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0462b4a72cf2ad3d55742660fce8bbc9db5effc3
--- /dev/null
+++ b/kohya-sd-scripts-webui/launch.py
@@ -0,0 +1,79 @@
+import install
+import subprocess
+import os
+import sys
+import importlib.util
+
+python = sys.executable
+git = os.environ.get("GIT", "git")
+index_url = os.environ.get("INDEX_URL", "")
+skip_install = False
+
+
+def run(command, desc=None, errdesc=None, custom_env=None):
+ if desc is not None:
+ print(desc)
+
+ result = subprocess.run(
+ command,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+ env=os.environ if custom_env is None else custom_env,
+ )
+
+ if result.returncode != 0:
+
+ message = f"""{errdesc or 'Error running command'}.
+Command: {command}
+Error code: {result.returncode}
+stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else ''}
+stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else ''}
+"""
+ raise RuntimeError(message)
+
+ return result.stdout.decode(encoding="utf8", errors="ignore")
+
+
+def check_run(command):
+ result = subprocess.run(
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
+ )
+ return result.returncode == 0
+
+
+def is_installed(package):
+ try:
+ spec = importlib.util.find_spec(package)
+ except ModuleNotFoundError:
+ return False
+
+ return spec is not None
+
+
+def run_pip(args, desc=None):
+ if skip_install:
+ return
+
+ index_url_line = f" --index-url {index_url}" if index_url != "" else ""
+ return run(
+ f'"{python}" -m pip {args} --prefer-binary{index_url_line}',
+ desc=f"Installing {desc}",
+ errdesc=f"Couldn't install {desc}",
+ )
+
+
+def run_python(code, desc=None, errdesc=None):
+ return run(f'"{python}" -c "{code}"', desc, errdesc)
+
+
+def extract_arg(args, name):
+ return [x for x in args if x != name], name in args
+
+
+if __name__ == "__main__":
+ install.prepare_environment()
+
+ from scripts import main
+
+ main.launch()
diff --git a/kohya-sd-scripts-webui/main.py b/kohya-sd-scripts-webui/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..24212ad0db18ff67be321d2b94e49c8ee6da9d72
--- /dev/null
+++ b/kohya-sd-scripts-webui/main.py
@@ -0,0 +1,14 @@
+import io
+import sys
+import subprocess
+
+ps = subprocess.Popen(
+ [sys.executable, "-u", "./sub.py"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT
+)
+
+reader = io.TextIOWrapper(ps.stdout, encoding='utf8')
+while ps.poll() is None:
+ char = reader.read(1)
+ if char == '\n':
+ print('break')
+ sys.stdout.write(char)
diff --git a/kohya-sd-scripts-webui/screenshots/installation-extension.png b/kohya-sd-scripts-webui/screenshots/installation-extension.png
new file mode 100644
index 0000000000000000000000000000000000000000..cbd75c6a824b2531db4caeb8923dda65aa2e355e
Binary files /dev/null and b/kohya-sd-scripts-webui/screenshots/installation-extension.png differ
diff --git a/kohya-sd-scripts-webui/screenshots/webui-01.png b/kohya-sd-scripts-webui/screenshots/webui-01.png
new file mode 100644
index 0000000000000000000000000000000000000000..66f57eb35619c15f5409a74c96328546e6dc5066
Binary files /dev/null and b/kohya-sd-scripts-webui/screenshots/webui-01.png differ
diff --git a/kohya-sd-scripts-webui/script.js b/kohya-sd-scripts-webui/script.js
new file mode 100644
index 0000000000000000000000000000000000000000..e860ff0a1644bf8f8228c0d0bb13ef7529018a16
--- /dev/null
+++ b/kohya-sd-scripts-webui/script.js
@@ -0,0 +1,87 @@
+/**
+ * Returns the root to query Gradio elements from: the shadow root of the
+ * first gradio-app element when it has one, otherwise the document.
+ */
+function gradioApp() {
+    const [host] = document.getElementsByTagName('gradio-app')
+    const shadow = host === undefined ? null : host.shadowRoot
+    return shadow || document
+}
+
+// Set to true once the one-time setup inside the MutationObserver callback has run.
+let executed = false
+
+/** @type {(() => void)[]} */
+
+/**
+ * Shows one of a tab's run/stop buttons and hides the other.
+ * @param {string} tab
+ * @param {boolean} show true shows the run button, false shows the stop button
+ */
+function kohya_sd_webui__toggle_runner_button(tab, show) {
+    const setDisplay = (kind, visible) => {
+        const button = gradioApp().getElementById(`kohya_sd_webui__${tab}_${kind}_button`)
+        button.style.display = visible ? 'block' : 'none'
+    }
+    setDisplay('run', show)
+    setDisplay('stop', !show)
+}
+
+window.addEventListener('DOMContentLoaded', () => {
+    // Gradio renders asynchronously: wait until the extension root exists and
+    // then run the setup exactly once.
+    const observer = new MutationObserver(() => {
+        if (executed || !gradioApp().querySelector('#kohya_sd_webui__root')) return
+        executed = true
+        // The setup is one-shot, so there is no reason to keep observing.
+        observer.disconnect()
+
+        // The two identifiers assigned below are placeholders that the server
+        // replaces with JSON literals before serving this file. Every textual
+        // occurrence is replaced, so do not repeat them anywhere else.
+        /** @type {Record<string, string>} element id -> tooltip text */
+        const helps = kohya_sd_webui__help_map
+        /** @type {string[]} */
+        const all_tabs = kohya_sd_webui__all_tabs
+
+        // Adds a "clear" button and polls the server once per second for new
+        // terminal lines, appending one <div> per line.
+        const initializeTerminalObserver = () => {
+            const container = gradioApp().querySelector("#kohya_sd_webui__terminal_outputs")
+            const parentContainer = container.parentElement
+            const clearBtn = document.createElement('button')
+            clearBtn.innerText = 'Clear The Terminal'
+            clearBtn.style.color = 'yellow';
+            parentContainer.insertBefore(clearBtn, container)
+            let clearTerminal = false;
+            clearBtn.addEventListener('click', () => {
+                container.innerHTML = ''
+                clearTerminal = true
+            })
+            setInterval(async () => {
+                try {
+                    const res = await fetch('./internal/extensions/kohya-sd-scripts-webui/terminal/outputs', {
+                        method: "POST",
+                        headers: { 'Content-Type': 'application/json' },
+                        body: JSON.stringify({
+                            // The number of rendered lines is the index of the next line to request.
+                            output_index: container.children.length,
+                            clear_terminal: clearTerminal,
+                        }),
+                    })
+                    clearTerminal = false
+                    const obj = await res.json()
+                    // scrollTop can be fractional, so compare with a tolerance
+                    // instead of strict equality.
+                    const isBottom = Math.abs(container.scrollHeight - container.scrollTop - container.clientHeight) < 1
+                    for (const line of obj.outputs) {
+                        const el = document.createElement('div')
+                        el.innerText = line
+                        container.appendChild(el)
+                    }
+                    // Only follow the output if the user was already at the bottom.
+                    if (isBottom) container.scrollTop = container.scrollHeight
+                } catch (e) {
+                    // Server unreachable or bad response: log and retry on the next tick.
+                    console.error(e)
+                }
+            }, 1000)
+        }
+
+        // Polls once per second whether a script is running and shows the
+        // matching run/stop button on every tab.
+        const checkProcessIsAlive = () => {
+            setInterval(async () => {
+                try {
+                    const res = await fetch('./internal/extensions/kohya-sd-scripts-webui/process/alive')
+                    const obj = await res.json()
+                    for (const tab of all_tabs)
+                        kohya_sd_webui__toggle_runner_button(tab, !obj.alive)
+                } catch (e) {
+                    console.error(e)
+                }
+            }, 1000)
+        }
+
+        initializeTerminalObserver()
+        checkProcessIsAlive()
+
+        // Switch to the stop button right away instead of waiting for the next poll.
+        for (const tab of all_tabs)
+            gradioApp().querySelector(`#kohya_sd_webui__${tab}_run_button`).addEventListener('click', () => kohya_sd_webui__toggle_runner_button(tab, false))
+
+        // Attach help texts as native tooltips; ids without a rendered element are skipped.
+        for (const [k, v] of Object.entries(helps)) {
+            // Declared locally: the bare assignment used before created an implicit global.
+            const el = gradioApp().getElementById(k)
+            if (!el) continue
+            el.title = v
+        }
+    })
+    observer.observe(gradioApp(), { childList: true, subtree: true })
+})
\ No newline at end of file
diff --git a/kohya-sd-scripts-webui/scripts/main.py b/kohya-sd-scripts-webui/scripts/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..87631a910e9048a475f152fbda8e4986466a56d8
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/main.py
@@ -0,0 +1,98 @@
+import json
+import os
+import time
+
+import gradio.routes
+
+import scripts.runner as runner
+import scripts.shared as shared
+from scripts.shared import ROOT_DIR, is_webui_extension
+from scripts.ui import create_ui
+
+
+def create_js():
+    """Return script.js with its two placeholder identifiers replaced by JSON.
+
+    ``kohya_sd_webui__help_map`` is replaced by ``shared.help_title_map`` and
+    ``kohya_sd_webui__all_tabs`` by ``shared.loaded_tabs``, in that order.
+    """
+    with open(os.path.join(ROOT_DIR, "script.js"), mode="r") as f:
+        source = f.read()
+
+    substitutions = (
+        ("kohya_sd_webui__help_map", shared.help_title_map),
+        ("kohya_sd_webui__all_tabs", shared.loaded_tabs),
+    )
+    for placeholder, value in substitutions:
+        source = source.replace(placeholder, json.dumps(value))
+    return source
+
+
+def create_head():
+    """Patch the Gradio template response so served pages include script.js.
+
+    The script produced by ``create_js`` is inlined in a ``<script>`` tag that
+    is inserted right before the closing ``</head>`` tag of the response body.
+    Inside the web UI the wrapper is stored on ``modules.shared``; standalone,
+    it replaces ``gradio.routes.templates.TemplateResponse``.
+    """
+    head = f'<script type="text/javascript">{create_js()}</script>'
+
+    def template_response_for_webui(*args, **kwargs):
+        res = shared.gradio_template_response_original(*args, **kwargs)
+        # Anchor on the closing tag: an empty search pattern would splice the
+        # payload between every byte of the page.
+        res.body = res.body.replace(b"</head>", f"{head}</head>".encode("utf8"))
+        return res
+
+    def template_response(*args, **kwargs):
+        res = template_response_for_webui(*args, **kwargs)
+        # The body was modified after construction, so rebuild the headers.
+        res.init_headers()
+        return res
+
+    if is_webui_extension():
+        import modules.shared
+
+        modules.shared.GradioTemplateResponseOriginal = template_response_for_webui
+    else:
+        gradio.routes.templates.TemplateResponse = template_response
+
+
+def wait_on_server():
+    """Block the calling thread forever, sleeping in 0.5 second steps."""
+    while True:
+        time.sleep(0.5)
+
+
+def on_ui_tabs():
+    """Build the UI and return it as a one-element list of tab tuples.
+
+    The tuple is ``(ui, "Kohya sd-scripts", "kohya_sd_scripts")``. The page
+    head is patched via ``create_head`` after the UI has been created.
+    """
+    with open(os.path.join(ROOT_DIR, "style.css"), mode="r") as f:
+        stylesheet = f.read()
+    blocks = create_ui(stylesheet)
+    create_head()
+    tab = (blocks, "Kohya sd-scripts", "kohya_sd_scripts")
+    return [tab]
+
+
+def launch():
+    """Run the UI as a standalone Gradio app; this call does not return.
+
+    Optionally opens an ngrok tunnel, launches Gradio without blocking,
+    registers the extension's API routes on the returned app and then parks
+    the calling thread in ``wait_on_server``.
+    """
+    block, _, _ = on_ui_tabs()[0]
+    if shared.cmd_opts.ngrok is not None:
+        import scripts.ngrok as ngrok
+
+        address = ngrok.connect(
+            shared.cmd_opts.ngrok,
+            shared.cmd_opts.port if shared.cmd_opts.port is not None else 7860,
+            shared.cmd_opts.ngrok_region,
+        )
+        # connect() returns None after reporting a failed tunnel; concatenating
+        # None to a str would raise TypeError and abort the whole launch.
+        if address is not None:
+            print("Running on ngrok URL: " + address)
+
+    app, local_url, share_url = block.launch(
+        share=shared.cmd_opts.share,
+        server_port=shared.cmd_opts.port,
+        server_name=shared.cmd_opts.host,
+        prevent_thread_lock=True,
+    )
+
+    runner.initialize_api(app)
+
+    wait_on_server()
+
+
+# Remember Gradio's original TemplateResponse only once; create_head() wraps the saved value.
+if not hasattr(shared, "gradio_template_response_original"):
+ shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse
+
+# Inside the web UI, register the tab and the API routes through its script callbacks.
+if is_webui_extension():
+ from modules import script_callbacks
+
+ # The first argument supplied by on_app_started is not used here.
+ def initialize_api(_, app):
+ runner.initialize_api(app)
+
+ script_callbacks.on_ui_tabs(on_ui_tabs)
+ script_callbacks.on_app_started(initialize_api)
+
+if __name__ == "__main__":
+ launch()
diff --git a/kohya-sd-scripts-webui/scripts/ngrok.py b/kohya-sd-scripts-webui/scripts/ngrok.py
new file mode 100644
index 0000000000000000000000000000000000000000..567e7fd9332c3e065c48f10c3e37bc30eb277cf9
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/ngrok.py
@@ -0,0 +1,28 @@
+def connect(token, port, region):
+    """Open an ngrok tunnel to ``port`` and return its public URL.
+
+    ``token`` is either the auth token alone or a colon-separated string. In
+    the latter case the first field is used as auth token and the second and
+    last fields, joined by ":", are passed to ngrok as ``auth``. When pyngrok
+    raises ``PyngrokNgrokError`` a message is printed and ``None`` is returned.
+    """
+    from pyngrok import conf, exception, ngrok
+
+    connect_kwargs = {"bind_tls": True}
+    if token is None:
+        token = "None"
+    elif ":" in token:
+        fields = token.split(":")
+        connect_kwargs["auth"] = fields[1] + ":" + fields[-1]
+        token = fields[0]
+
+    connect_kwargs["pyngrok_config"] = conf.PyngrokConfig(auth_token=token, region=region)
+    try:
+        tunnel = ngrok.connect(port, **connect_kwargs)
+        return tunnel.public_url
+    except exception.PyngrokNgrokError:
+        print(
+            f"Invalid ngrok authtoken, ngrok connection aborted.\n"
+            f"Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken"
+        )
diff --git a/kohya-sd-scripts-webui/scripts/presets.py b/kohya-sd-scripts-webui/scripts/presets.py
new file mode 100644
index 0000000000000000000000000000000000000000..646e07034e125502c874aa37322bbdaebac67e60
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/presets.py
@@ -0,0 +1,179 @@
+import argparse
+import inspect
+import os
+from pathlib import Path
+import toml
+from kohya_ss.library import train_util, config_util
+
+import gradio as gr
+
+from scripts.shared import ROOT_DIR
+from scripts.utilities import gradio_to_args
+
+PRESET_DIR = os.path.join(ROOT_DIR, "presets")
+PRESET_PATH = os.path.join(ROOT_DIR, "presets.json")
+
+
+def get_arg_templates(fn):
+    """Collect the option names that ``fn`` registers on an argparse parser.
+
+    ``fn`` is called with a fresh ``ArgumentParser`` followed by ``True`` for
+    each of its remaining parameters. Returns ``(names, category)``: ``names``
+    are the registered option strings with every "--" removed, excluding
+    "help" and "-h"; ``category`` is ``fn.__name__`` with "add_" removed.
+    """
+    parser = argparse.ArgumentParser()
+    extra_params = len(inspect.signature(fn).parameters) - 1
+    fn(parser, *([True] * extra_params))
+    names = []
+    for option_string in parser.__dict__["_option_string_actions"].keys():
+        name = option_string.replace("--", "")
+        if name not in ["help", "-h"]:
+            names.append(name)
+    return names, fn.__name__.replace("add_", "")
+
+
+# Argument-registering functions whose options are grouped into categories when presets are saved.
+arguments_functions = [
+ train_util.add_dataset_arguments,
+ train_util.add_optimizer_arguments,
+ train_util.add_sd_models_arguments,
+ train_util.add_sd_saving_arguments,
+ train_util.add_training_arguments,
+ config_util.add_config_arguments,
+]
+
+# List of (option names, category) pairs, computed once at import time.
+arg_templates = [get_arg_templates(x) for x in arguments_functions]
+
+
+def load_presets():
+    """Load every preset stored under ``PRESET_DIR`` (created if missing).
+
+    Returns ``{group_dir_name: {preset_name: flattened_values}}``. Only
+    sub-directories count as preset groups and only ``*.toml`` files inside
+    them as presets, so stray files such as ``.DS_Store`` neither raise
+    ``NotADirectoryError`` nor show up as empty presets.
+    """
+    obj = {}
+    os.makedirs(PRESET_DIR, exist_ok=True)
+    for preset_name in os.listdir(PRESET_DIR):
+        preset_path = os.path.join(PRESET_DIR, preset_name)
+        if not os.path.isdir(preset_path):
+            continue
+        obj[preset_name] = {}
+        for filename in os.listdir(preset_path):
+            # splitext removes only the final suffix; str.replace would also
+            # strip ".toml" from the middle of a name.
+            key, ext = os.path.splitext(filename)
+            if ext != ".toml":
+                continue
+            obj[preset_name][key] = load_preset(preset_name, key)
+    return obj
+
+
+def load_preset(key, name):
+    """Read ``PRESET_DIR/<key>/<name>.toml`` and flatten it by one level.
+
+    Top-level non-table values are kept under their own key; the entries of
+    each top-level table are merged into the result. Returns ``{}`` when the
+    file does not exist.
+    """
+    filepath = os.path.join(PRESET_DIR, key, name + ".toml")
+    if not os.path.exists(filepath):
+        return {}
+    with open(filepath, mode="r") as f:
+        document = toml.load(f)
+
+    flattened = {}
+    for section, content in document.items():
+        if isinstance(content, dict):
+            flattened.update(content)
+        else:
+            flattened[section] = content
+    return flattened
+
+
+def save_preset(key, name, value):
+ """Write ``value`` to ``PRESET_DIR/<key>/<name>.toml``, creating directories as needed.
+
+ An option whose name appears in one of ``arg_templates`` is stored under
+ that template's category table. ``Path`` values are converted to ``str``.
+ """
+ obj = {}
+ for k, v in value.items():
+ if isinstance(v, Path):
+ v = str(v)
+ for (template, category) in arg_templates:
+ if k in template:
+ if category not in obj:
+ obj[category] = {}
+ obj[category][k] = v
+ break
+ # NOTE(review): presumably the ``else`` of the ``for`` above (no template
+ # matched), keeping unmatched options at the top level — confirm.
+ else:
+ obj[k] = v
+
+ filepath = os.path.join(PRESET_DIR, key, name + ".toml")
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
+ with open(filepath, mode="w") as f:
+ toml.dump(obj, f)
+
+
+def delete_preset(key, name):
+    """Delete ``PRESET_DIR/<key>/<name>.toml`` if it exists."""
+    target = os.path.join(PRESET_DIR, key, name + ".toml")
+    if not os.path.exists(target):
+        return
+    os.remove(target)
+
+
+def create_ui(key, tmpls, opts):
+ """Build the load/delete/reload/save preset controls for the preset group ``key``.
+
+ ``tmpls`` and ``opts`` may be plain values or zero-argument callables; they
+ are resolved each time they are needed. Returns an ``init`` callable that
+ registers the click handlers; it reads the option components when called.
+ """
+ get_templates = lambda: tmpls() if callable(tmpls) else tmpls
+ get_options = lambda: opts() if callable(opts) else opts
+
+ presets = load_presets()
+
+ if key not in presets:
+ presets[key] = {}
+
+ with gr.Box():
+ with gr.Row():
+ with gr.Column() as c:
+ load_preset_button = gr.Button("Load preset", variant="primary")
+ delete_preset_button = gr.Button("Delete preset")
+ with gr.Column() as c:
+ load_preset_name = gr.Dropdown(
+ list(presets[key].keys()), show_label=False
+ ).style(container=False)
+ reload_presets_button = gr.Button("🔄️")
+ with gr.Column() as c:
+ c.scale = 0.5
+ save_preset_name = gr.Textbox(
+ "", placeholder="Preset name", lines=1, show_label=False
+ ).style(container=False)
+ save_preset_button = gr.Button("Save preset", variant="primary")
+
+ # Re-read the presets from disk and return a Dropdown update with this group's names.
+ def update_dropdown():
+ presets = load_presets()
+ if key not in presets:
+ presets[key] = {}
+ return gr.Dropdown.update(choices=list(presets[key].keys()))
+
+ # ``args`` is indexed by component here, since the handler's inputs are passed as a set.
+ def _save_preset(args):
+ name = args[save_preset_name]
+ if not name:
+ return update_dropdown()
+ args = gradio_to_args(get_templates(), get_options(), args)
+ save_preset(key, name, args)
+ return update_dropdown()
+
+ def _load_preset(args):
+ name = args[load_preset_name]
+ # NOTE(review): this returns a Dropdown update although the handler's outputs
+ # are the option components — confirm this early return is intended.
+ if not name:
+ return update_dropdown()
+ args = gradio_to_args(get_templates(), get_options(), args)
+ preset = load_preset(key, name)
+ result = []
+ for k, _ in args.items():
+ if k == load_preset_name:
+ continue
+ # Options missing from the preset are returned as None.
+ if k not in preset:
+ result.append(None)
+ continue
+ v = preset[k]
+ # List values are joined into one space-separated string.
+ if type(v) == list:
+ v = " ".join(v)
+ result.append(v)
+ return result[0] if len(result) == 1 else result
+
+ def _delete_preset(name):
+ if not name:
+ return update_dropdown()
+ delete_preset(key, name)
+ return update_dropdown()
+
+ # Registers all click handlers; get_options() is evaluated when init() is called.
+ def init():
+ save_preset_button.click(
+ _save_preset,
+ set([save_preset_name, *get_options().values()]),
+ [load_preset_name],
+ )
+ load_preset_button.click(
+ _load_preset,
+ set([load_preset_name, *get_options().values()]),
+ [*get_options().values()],
+ )
+ delete_preset_button.click(_delete_preset, load_preset_name, [load_preset_name])
+ reload_presets_button.click(
+ update_dropdown, inputs=[], outputs=[load_preset_name]
+ )
+
+ return init
diff --git a/kohya-sd-scripts-webui/scripts/runner.py b/kohya-sd-scripts-webui/scripts/runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a34d616ee4e457ce95e5742efa83f417c4b97e7
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/runner.py
@@ -0,0 +1,113 @@
+import io
+import sys
+
+import fastapi
+import gradio as gr
+from pydantic import BaseModel, Field
+
+import scripts.shared as shared
+from scripts.utilities import run_python
+
+# Handle returned by run_python for the script currently running, or None when idle.
+proc = None
+# Newline-terminated output lines captured from the running script.
+outputs = []
+
+
+def alive():
+    """Report whether a script process is currently tracked in ``proc``."""
+    if proc is None:
+        return False
+    return True
+
+
+def initialize_runner(script_file, tmpls, opts):
+    """Create the Run/Stop buttons for the tab named by ``shared.current_tab``.
+
+    ``tmpls`` and ``opts`` may be plain values or zero-argument callables; they
+    are resolved each time they are needed. Returns an ``init`` callable that
+    registers the click handlers; it reads the option components when called.
+    """
+    run_button = gr.Button(
+        "Run",
+        variant="primary",
+        elem_id=f"kohya_sd_webui__{shared.current_tab}_run_button",
+    )
+    stop_button = gr.Button(
+        "Stop",
+        variant="secondary",
+        elem_id=f"kohya_sd_webui__{shared.current_tab}_stop_button",
+    )
+    get_templates = lambda: tmpls() if callable(tmpls) else tmpls
+    get_options = lambda: opts() if callable(opts) else opts
+
+    def run(args):
+        """Start the script and capture its output line by line until it ends."""
+        global proc
+        global outputs
+        # Only one script may run at a time.
+        if alive():
+            return
+        proc = run_python(script_file, get_templates(), get_options(), args)
+        process = proc
+        reader = io.TextIOWrapper(process.stdout, encoding="utf-8-sig")
+        line = ""
+        # stop() resets the global to None. Poll the local handle so that a
+        # concurrent stop cannot turn ``proc.poll()`` into an AttributeError.
+        while proc is not None and process.poll() is None:
+            try:
+                char = reader.read(1)
+                if shared.cmd_opts.enable_console_log:
+                    sys.stdout.write(char)
+                if char == "\n":
+                    outputs.append(line)
+                    line = ""
+                    continue
+                line += char
+            except Exception:
+                # Best effort: a read or decode error must not end the capture.
+                pass
+        # Keep a final line that was not newline-terminated.
+        if line:
+            outputs.append(line)
+        proc = None
+
+    def stop():
+        """Kill the running script, if there is one."""
+        global proc
+        # Nothing to kill when no script is running or it already finished.
+        if proc is None:
+            return
+        print("killed the running process")
+        proc.kill()
+        proc = None
+
+    def init():
+        run_button.click(
+            run,
+            set(get_options().values()),
+        )
+        stop_button.click(stop)
+
+    return init
+
+
+class GetOutputRequest(BaseModel):
+ """Request body of the terminal-outputs endpoint handled by ``api_get_outputs``."""
+ # First line to return; the handler answers with outputs[output_index:].
+ output_index: int = Field(
+ default=0, title="Index of the beginning of the log to retrieve"
+ )
+ # When true, the handler discards the stored lines before answering.
+ clear_terminal: bool = Field(
+ default=False, title="Whether to clear the terminal"
+ )
+
+
+class GetOutputResponse(BaseModel):
+ """Response of the terminal-outputs endpoint: the requested slice of captured lines."""
+ outputs: list = Field(title="List of terminal output")
+
+
+class ProcessAliveResponse(BaseModel):
+ """Response of the process-alive endpoint; ``alive`` carries the result of ``alive()``."""
+ alive: bool = Field(title="Whether the process is running.")
+
+
+def api_get_outputs(req: GetOutputRequest):
+    """Return the captured output lines starting at ``req.output_index``.
+
+    When ``req.clear_terminal`` is set, the stored lines are discarded first.
+    """
+    global outputs
+    start = req.output_index
+    if req.clear_terminal:
+        outputs = []
+    pending = [] if len(outputs) <= start else outputs[start:]
+    return GetOutputResponse(outputs=pending)
+
+
+def api_get_isalive(req: fastapi.Request):
+    """Return whether a script process is currently tracked; ``req`` is unused."""
+    running = alive()
+    return ProcessAliveResponse(alive=running)
+
+
+def initialize_api(app: fastapi.FastAPI):
+    """Register the terminal-outputs (POST) and process-alive (GET) routes on ``app``."""
+    routes = (
+        (
+            "/internal/extensions/kohya-sd-scripts-webui/terminal/outputs",
+            api_get_outputs,
+            "POST",
+            GetOutputResponse,
+        ),
+        (
+            "/internal/extensions/kohya-sd-scripts-webui/process/alive",
+            api_get_isalive,
+            "GET",
+            ProcessAliveResponse,
+        ),
+    )
+    for path, endpoint, method, model in routes:
+        app.add_api_route(path, endpoint, methods=[method], response_model=model)
diff --git a/kohya-sd-scripts-webui/scripts/shared.py b/kohya-sd-scripts-webui/scripts/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c7ccaa6f54471afb1f897b7c38022fdf178cb66
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/shared.py
@@ -0,0 +1,32 @@
+import argparse
+import importlib
+import os
+import sys
+
+
+def is_webui_extension():
+    """Return True when the ``webui`` module can be imported, otherwise False."""
+    try:
+        importlib.import_module("webui")
+    except:
+        return False
+    else:
+        return True
+
+
+# Extension root: modules.scripts.basedir() inside the web UI, otherwise the
+# parent of the directory that contains this file.
+ROOT_DIR = (
+ importlib.import_module("modules.scripts").basedir()
+ if is_webui_extension()
+ else os.path.dirname(os.path.dirname(__file__))
+)
+
+# Shared mutable state. scripts.runner reads current_tab for its button ids;
+# scripts.main serialises loaded_tabs and help_title_map into script.js.
+current_tab = None
+loaded_tabs = []
+help_title_map = {}
+
+# Command-line options; arguments not declared here are ignored (parse_known_args).
+parser = argparse.ArgumentParser()
+parser.add_argument("--share", action="store_true")
+parser.add_argument("--port", type=int, default=None)
+parser.add_argument("--host", type=str, default=None)
+parser.add_argument("--ngrok", type=str, default=None)
+parser.add_argument("--ngrok-region", type=str, default="us")
+parser.add_argument("--enable-console-log", action="store_true")
+cmd_opts, _ = parser.parse_known_args(sys.argv)
diff --git a/kohya-sd-scripts-webui/scripts/tabs/networks/check_lora_weights.py b/kohya-sd-scripts-webui/scripts/tabs/networks/check_lora_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddad2ebaa9a69db10da363b7e0b49d39e5852217
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/networks/check_lora_weights.py
@@ -0,0 +1,23 @@
+import gradio as gr
+
+from scripts import ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    # Fixed the user-facing typo "wights".
+    return "Check lora weights"
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons plus an "Options" box for check_lora_weights.py."""
+ options = {}
+ templates, script_file = load_args_template("networks", "check_lora_weights.py")
+
+ with gr.Column():
+ init = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Registers the Run/Stop click handlers; the runner reads `options` when this is called.
+ init()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/networks/extract_lora_from_models.py b/kohya-sd-scripts-webui/scripts/tabs/networks/extract_lora_from_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77f790ab8bc92424f3ec539a9c6178a598800f6
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/networks/extract_lora_from_models.py
@@ -0,0 +1,25 @@
+import gradio as gr
+
+from scripts import ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Extract lora from models"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons plus an "Options" box for extract_lora_from_models.py."""
+ options = {}
+ templates, script_file = load_args_template(
+ "networks", "extract_lora_from_models.py"
+ )
+
+ with gr.Column():
+ init = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Registers the Run/Stop click handlers; the runner reads `options` when this is called.
+ init()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/networks/lora_interrogator.py b/kohya-sd-scripts-webui/scripts/tabs/networks/lora_interrogator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f11c5710ef1c2451d2b803ac50c5a00aae5186c
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/networks/lora_interrogator.py
@@ -0,0 +1,23 @@
+import gradio as gr
+
+from scripts import ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Lora interrogator"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons plus an "Options" box for lora_interrogator.py."""
+ options = {}
+ templates, script_file = load_args_template("networks", "lora_interrogator.py")
+
+ with gr.Column():
+ init = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Registers the Run/Stop click handlers; the runner reads `options` when this is called.
+ init()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/networks/merge_lora.py b/kohya-sd-scripts-webui/scripts/tabs/networks/merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..74f534d45f383fa373a06d3367a208783732bfc0
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/networks/merge_lora.py
@@ -0,0 +1,23 @@
+import gradio as gr
+
+from scripts import ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Merge lora"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons plus an "Options" box for merge_lora.py."""
+ options = {}
+ templates, script_file = load_args_template("networks", "merge_lora.py")
+
+ with gr.Column():
+ init = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Registers the Run/Stop click handlers; the runner reads `options` when this is called.
+ init()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/networks/resize_lora.py b/kohya-sd-scripts-webui/scripts/tabs/networks/resize_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..98edbf99f427ef38128321e55d71187e21e167a9
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/networks/resize_lora.py
@@ -0,0 +1,23 @@
+import gradio as gr
+
+from scripts import ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Resize lora"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons plus an "Options" box for resize_lora.py."""
+ options = {}
+ templates, script_file = load_args_template("networks", "resize_lora.py")
+
+ with gr.Column():
+ init = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Registers the Run/Stop click handlers; the runner reads `options` when this is called.
+ init()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/networks/svd_merge_lora.py b/kohya-sd-scripts-webui/scripts/tabs/networks/svd_merge_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca089e8cebf309c3d3d97f236497a9fb7c64ef86
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/networks/svd_merge_lora.py
@@ -0,0 +1,23 @@
+import gradio as gr
+
+from scripts import ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Svd merge lora"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons plus an "Options" box for svd_merge_lora.py."""
+ options = {}
+ templates, script_file = load_args_template("networks", "svd_merge_lora.py")
+
+ with gr.Column():
+ init = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Registers the Run/Stop click handlers; the runner reads `options` when this is called.
+ init()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/clean_captions_and_tags.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/clean_captions_and_tags.py
new file mode 100644
index 0000000000000000000000000000000000000000..2da949256a4c3d2b53821ffcc9ceb1d415344477
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/clean_captions_and_tags.py
@@ -0,0 +1,37 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Clean captions and tags"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for clean_captions_and_tags.py.
+
+ Any exception raised while building is printed with its traceback and swallowed.
+ """
+ import traceback
+
+ try:
+ options = {}
+ templates, script_file = load_args_template(
+ "finetune", "clean_captions_and_tags.py"
+ )
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.clean_captions_and_tags" key.
+ init_ui = presets.create_ui(
+ "finetune.clean_captions_and_tags", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
+
+ except:
+ traceback.print_exc()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/make_captions.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/make_captions.py
new file mode 100644
index 0000000000000000000000000000000000000000..38a2c67f1af87b88b526141606af26cb78d41ec1
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/make_captions.py
@@ -0,0 +1,29 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Make captions"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for make_captions.py."""
+ options = {}
+ templates, script_file = load_args_template("finetune", "make_captions.py")
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.make_captions" key.
+ init_ui = presets.create_ui(
+ "finetune.make_captions", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/make_captions_by_git.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/make_captions_by_git.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef4389644b66bc136ab84741602af3f8991cfa7
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/make_captions_by_git.py
@@ -0,0 +1,29 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Make captions by GIT"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for make_captions_by_git.py."""
+ options = {}
+ templates, script_file = load_args_template("finetune", "make_captions_by_git.py")
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.make_captions_by_git" key.
+ init_ui = presets.create_ui(
+ "finetune.make_captions_by_git", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/merge_captions.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/merge_captions.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a1b45699f35a17b5a844b9ab54f4ac5ff589cfd
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/merge_captions.py
@@ -0,0 +1,31 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Merge captions"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for merge_captions_to_metadata.py."""
+ options = {}
+ templates, script_file = load_args_template(
+ "finetune", "merge_captions_to_metadata.py"
+ )
+
+ with gr.Column():
+ # `inti_runner` (sic) holds the runner's init callable.
+ inti_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.merge_captions_to_metadata" key.
+ init_ui = presets.create_ui(
+ "finetune.merge_captions_to_metadata", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ inti_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/merge_tags.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/merge_tags.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbc1317c19defb80a55aeb8ce665d4199066c660
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/merge_tags.py
@@ -0,0 +1,31 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Merge tags"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for merge_dd_tags_to_metadata.py."""
+ options = {}
+ templates, script_file = load_args_template(
+ "finetune", "merge_dd_tags_to_metadata.py"
+ )
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.merge_dd_tags_to_metadata" key.
+ # `init_id` is the preset UI's init callable (named `init_ui` in sibling tabs).
+ init_id = presets.create_ui(
+ "finetune.merge_dd_tags_to_metadata", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_id()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/prepare_latents.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/prepare_latents.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d80252d5d7d0643e72f0feb63e2690c04f9d97
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/prepare_latents.py
@@ -0,0 +1,31 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Prepare latents"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for prepare_buckets_latents.py."""
+ options = {}
+ templates, script_file = load_args_template(
+ "finetune", "prepare_buckets_latents.py"
+ )
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.prepare_buckets_latents" key.
+ init_ui = presets.create_ui(
+ "finetune.prepare_buckets_latents", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/preparation/tag_images_by_wd14tagger.py b/kohya-sd-scripts-webui/scripts/tabs/preparation/tag_images_by_wd14tagger.py
new file mode 100644
index 0000000000000000000000000000000000000000..50ebac56a64258e42df2e34374d3add5f1476d93
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/preparation/tag_images_by_wd14tagger.py
@@ -0,0 +1,31 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Tag images by wd1.4tagger"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for tag_images_by_wd14_tagger.py."""
+ options = {}
+ templates, script_file = load_args_template(
+ "finetune", "tag_images_by_wd14_tagger.py"
+ )
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "finetune.tag_images_by_wd14_tagger" key.
+ # `init_id` is the preset UI's init callable (named `init_ui` in sibling tabs).
+ init_id = presets.create_ui(
+ "finetune.tag_images_by_wd14_tagger", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_id()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/tools/convert_diffusers.py b/kohya-sd-scripts-webui/scripts/tabs/tools/convert_diffusers.py
new file mode 100644
index 0000000000000000000000000000000000000000..709cdd254836b14d8ad9602396eb068482fb8755
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/tools/convert_diffusers.py
@@ -0,0 +1,31 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Convert Diffusers"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for convert_diffusers20_original_sd.py."""
+ options = {}
+ templates, script_file = load_args_template(
+ "tools", "convert_diffusers20_original_sd.py"
+ )
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # NOTE(review): unlike the other tools tabs this preset key has no "tools." prefix.
+ # The key is the preset directory name, so renaming it would orphan saved presets.
+ init_ui = presets.create_ui(
+ "convert_diffusers20_original_sd", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/tools/detect_face_rotate.py b/kohya-sd-scripts-webui/scripts/tabs/tools/detect_face_rotate.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc4d0fc340a33437fb0ad8a175aeb8b2268a2d8
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/tools/detect_face_rotate.py
@@ -0,0 +1,30 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Detect face rotate"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for detect_face_rotate.py."""
+ options = {}
+
+ templates, script_file = load_args_template("tools", "detect_face_rotate.py")
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "tools.detect_face_rotate" key.
+ init_ui = presets.create_ui(
+ "tools.detect_face_rotate", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/tools/resize_images_to_resolution.py b/kohya-sd-scripts-webui/scripts/tabs/tools/resize_images_to_resolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..2065ee4c68e02203690007da77971f7e176ba225
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/tools/resize_images_to_resolution.py
@@ -0,0 +1,32 @@
+import gradio as gr
+
+from scripts import presets, ui
+from scripts.runner import initialize_runner
+from scripts.utilities import load_args_template, options_to_gradio
+
+
+def title():
+    """Return the title string of this tab module."""
+    label = "Resize images to resolution"
+    return label
+
+
+def create_ui():
+ """Build this tab: Run/Stop buttons, preset controls and an "Options" box for resize_images_to_resolution.py."""
+ options = {}
+
+ templates, script_file = load_args_template(
+ "tools", "resize_images_to_resolution.py"
+ )
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, templates, options)
+ with gr.Box():
+ with gr.Row():
+ # Presets for this tab are stored under the "tools.resize_images_to_resolution" key.
+ init_ui = presets.create_ui(
+ "tools.resize_images_to_resolution", templates, options
+ )
+ with gr.Box():
+ ui.title("Options")
+ with gr.Column():
+ # Presumably fills `options` from `templates` (scripts.utilities, not shown here) — confirm.
+ options_to_gradio(templates, options)
+
+ # Register click handlers last: both init callables read `options` when called.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/training/fine_tune.py b/kohya-sd-scripts-webui/scripts/tabs/training/fine_tune.py
new file mode 100644
index 0000000000000000000000000000000000000000..a52f1d37db84ef0380b4f16f98edae729d572cf4
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/training/fine_tune.py
@@ -0,0 +1,93 @@
+import argparse
+
+import gradio as gr
+
+from kohya_ss.library import train_util, config_util
+from scripts import presets, ui, ui_overrides
+from scripts.runner import initialize_runner
+from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
+
+
+def title():
+ """Return the label used for this tab."""
+ return "Fine tune"
+
+
+def create_ui():
+ """Build the Gradio UI for the "Fine tune" training tab.
+
+ Options come from two sources: the shared argument groups that kohya_ss
+ train_util/config_util register on scratch ArgumentParser objects, and the
+ script-specific arguments parsed from kohya_ss/fine_tune.py.
+ """
+ # One scratch parser per argument group so each group can be rendered in its own box.
+ sd_models_arguments = argparse.ArgumentParser()
+ dataset_arguments = argparse.ArgumentParser()
+ training_arguments = argparse.ArgumentParser()
+ sd_saving_arguments = argparse.ArgumentParser()
+ optimizer_arguments = argparse.ArgumentParser()
+ config_arguments = argparse.ArgumentParser()
+ train_util.add_sd_models_arguments(sd_models_arguments)
+ train_util.add_dataset_arguments(dataset_arguments, False, True, True)
+ train_util.add_training_arguments(training_arguments, False)
+ train_util.add_sd_saving_arguments(sd_saving_arguments)
+ train_util.add_optimizer_arguments(optimizer_arguments)
+ config_util.add_config_arguments(config_arguments)
+ # Per-group dicts filled with argument name -> component when the UI is built below.
+ sd_models_options = {}
+ dataset_options = {}
+ training_options = {}
+ sd_saving_options = {}
+ optimizer_options = {}
+ config_options = {}
+ finetune_options = {}
+
+ templates, script_file = load_args_template("fine_tune.py")
+
+ # Lazy on purpose: the option dicts are still empty here and are only populated
+ # further down. On duplicate keys the later dict wins.
+ get_options = lambda: {
+ **sd_models_options,
+ **dataset_options,
+ **training_options,
+ **sd_saving_options,
+ **optimizer_options,
+ **finetune_options,
+ **config_options,
+ }
+
+ # Merges argparse's private option-string -> Action mappings with the templates
+ # parsed from the script; the script templates win on duplicate keys.
+ get_templates = lambda: {
+ **sd_models_arguments.__dict__["_option_string_actions"],
+ **dataset_arguments.__dict__["_option_string_actions"],
+ **training_arguments.__dict__["_option_string_actions"],
+ **sd_saving_arguments.__dict__["_option_string_actions"],
+ **optimizer_arguments.__dict__["_option_string_actions"],
+ **config_arguments.__dict__["_option_string_actions"],
+ **templates,
+ }
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, get_templates, get_options)
+ with gr.Box():
+ with gr.Row():
+ init_ui = presets.create_ui("fine_tune", get_templates, get_options)
+ with gr.Row():
+ with gr.Group():
+ with gr.Box():
+ ui.title("Fine tune options")
+ options_to_gradio(templates, finetune_options)
+ with gr.Box():
+ ui.title("Model options")
+ args_to_gradio(sd_models_arguments, sd_models_options)
+ with gr.Box():
+ ui.title("Dataset options")
+ args_to_gradio(dataset_arguments, dataset_options)
+ with gr.Box():
+ ui.title("Dataset Config options")
+ args_to_gradio(config_arguments, config_options)
+ with gr.Box():
+ ui.title("Training options")
+ args_to_gradio(training_arguments, training_options)
+ with gr.Group():
+ with gr.Box():
+ ui.title("Save options")
+ args_to_gradio(sd_saving_arguments, sd_saving_options)
+ with gr.Box():
+ ui.title("Optimizer options")
+ # The overrides turn optimizer_type / lr_scheduler into fixed choice lists.
+ args_to_gradio(
+ optimizer_arguments,
+ optimizer_options,
+ ui_overrides.OPTIMIZER_OPTIONS,
+ )
+
+ # Both helpers returned callables; they are invoked only after all components exist.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/training/train_db.py b/kohya-sd-scripts-webui/scripts/tabs/training/train_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..2954b8603accc92c94fa6e3b4bd09fe13777cb7e
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/training/train_db.py
@@ -0,0 +1,93 @@
+import argparse
+
+import gradio as gr
+
+from kohya_ss.library import train_util, config_util
+from scripts import presets, ui, ui_overrides
+from scripts.runner import initialize_runner
+from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
+
+
+def title():
+ """Return the label used for this tab."""
+ return "Train dreambooth"
+
+
+def create_ui():
+ """Build the Gradio UI for the "Train dreambooth" training tab.
+
+ Options come from two sources: the shared argument groups that kohya_ss
+ train_util/config_util register on scratch ArgumentParser objects, and the
+ script-specific arguments parsed from kohya_ss/train_db.py.
+ """
+ # One scratch parser per argument group so each group can be rendered in its own box.
+ sd_models_arguments = argparse.ArgumentParser()
+ dataset_arguments = argparse.ArgumentParser()
+ training_arguments = argparse.ArgumentParser()
+ sd_saving_arguments = argparse.ArgumentParser()
+ optimizer_arguments = argparse.ArgumentParser()
+ config_arguments = argparse.ArgumentParser()
+ train_util.add_sd_models_arguments(sd_models_arguments)
+ train_util.add_dataset_arguments(dataset_arguments, True, False, True)
+ train_util.add_training_arguments(training_arguments, True)
+ train_util.add_sd_saving_arguments(sd_saving_arguments)
+ train_util.add_optimizer_arguments(optimizer_arguments)
+ config_util.add_config_arguments(config_arguments)
+ # Per-group dicts filled with argument name -> component when the UI is built below.
+ sd_models_options = {}
+ dataset_options = {}
+ training_options = {}
+ sd_saving_options = {}
+ optimizer_options = {}
+ config_options = {}
+ dreambooth_options = {}
+
+ templates, script_file = load_args_template("train_db.py")
+
+ # Lazy on purpose: the option dicts are still empty here and are only populated
+ # further down. On duplicate keys the later dict wins.
+ get_options = lambda: {
+ **sd_models_options,
+ **dataset_options,
+ **training_options,
+ **sd_saving_options,
+ **optimizer_options,
+ **config_options,
+ **dreambooth_options,
+ }
+
+ # Merges argparse's private option-string -> Action mappings with the templates
+ # parsed from the script; the script templates win on duplicate keys.
+ get_templates = lambda: {
+ **sd_models_arguments.__dict__["_option_string_actions"],
+ **dataset_arguments.__dict__["_option_string_actions"],
+ **training_arguments.__dict__["_option_string_actions"],
+ **sd_saving_arguments.__dict__["_option_string_actions"],
+ **optimizer_arguments.__dict__["_option_string_actions"],
+ **config_arguments.__dict__["_option_string_actions"],
+ **templates,
+ }
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, get_templates, get_options)
+ with gr.Box():
+ with gr.Row():
+ init_ui = presets.create_ui("train_db", get_templates, get_options)
+ with gr.Row():
+ with gr.Group():
+ with gr.Box():
+ ui.title("Dreambooth options")
+ options_to_gradio(templates, dreambooth_options)
+ with gr.Box():
+ ui.title("Model options")
+ args_to_gradio(sd_models_arguments, sd_models_options)
+ with gr.Box():
+ ui.title("Dataset options")
+ args_to_gradio(dataset_arguments, dataset_options)
+ with gr.Box():
+ ui.title("Dataset Config options")
+ args_to_gradio(config_arguments, config_options)
+ with gr.Box():
+ ui.title("Training options")
+ args_to_gradio(training_arguments, training_options)
+ with gr.Group():
+ with gr.Box():
+ ui.title("Save options")
+ args_to_gradio(sd_saving_arguments, sd_saving_options)
+ with gr.Box():
+ ui.title("Optimizer options")
+ # The overrides turn optimizer_type / lr_scheduler into fixed choice lists.
+ args_to_gradio(
+ optimizer_arguments,
+ optimizer_options,
+ ui_overrides.OPTIMIZER_OPTIONS,
+ )
+
+ # Both helpers returned callables; they are invoked only after all components exist.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/training/train_network.py b/kohya-sd-scripts-webui/scripts/tabs/training/train_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e15ec8e8d451407b4dde8bf678194a6d93b415f
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/training/train_network.py
@@ -0,0 +1,84 @@
+import argparse
+
+import gradio as gr
+
+from kohya_ss.library import train_util, config_util
+from scripts import presets, ui, ui_overrides
+from scripts.runner import initialize_runner
+from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
+
+
+def title():
+ """Return the label used for this tab."""
+ return "Train network"
+
+
+def create_ui():
+ """Build the Gradio UI for the "Train network" training tab.
+
+ Options come from two sources: the shared argument groups that kohya_ss
+ train_util/config_util register on scratch ArgumentParser objects, and the
+ script-specific arguments parsed from kohya_ss/train_network.py.
+ """
+ # One scratch parser per argument group so each group can be rendered in its own box.
+ sd_models_arguments = argparse.ArgumentParser()
+ dataset_arguments = argparse.ArgumentParser()
+ training_arguments = argparse.ArgumentParser()
+ optimizer_arguments = argparse.ArgumentParser()
+ config_arguments = argparse.ArgumentParser()
+ train_util.add_sd_models_arguments(sd_models_arguments)
+ train_util.add_dataset_arguments(dataset_arguments, True, True, True)
+ train_util.add_training_arguments(training_arguments, True)
+ train_util.add_optimizer_arguments(optimizer_arguments)
+ config_util.add_config_arguments(config_arguments)
+ # Per-group dicts filled with argument name -> component when the UI is built below.
+ sd_models_options = {}
+ dataset_options = {}
+ training_options = {}
+ optimizer_options = {}
+ config_options = {}
+ network_options = {}
+
+ templates, script_file = load_args_template("train_network.py")
+
+ # Lazy on purpose: the option dicts are still empty here and are only populated
+ # further down. On duplicate keys the later dict wins.
+ get_options = lambda: {
+ **sd_models_options,
+ **dataset_options,
+ **training_options,
+ **optimizer_options,
+ **config_options,
+ **network_options,
+ }
+
+ # Merges argparse's private option-string -> Action mappings with the templates
+ # parsed from the script; the script templates win on duplicate keys.
+ get_templates = lambda: {
+ **sd_models_arguments.__dict__["_option_string_actions"],
+ **dataset_arguments.__dict__["_option_string_actions"],
+ **training_arguments.__dict__["_option_string_actions"],
+ **optimizer_arguments.__dict__["_option_string_actions"],
+ **config_arguments.__dict__["_option_string_actions"],
+ **templates,
+ }
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, get_templates, get_options)
+ with gr.Box():
+ with gr.Row():
+ # NOTE(review): named init_id here, while the other tab modules call the
+ # presets.create_ui() result init_ui.
+ init_id = presets.create_ui("train_network", get_templates, get_options)
+ with gr.Row():
+ with gr.Group():
+ with gr.Box():
+ ui.title("Network options")
+ options_to_gradio(templates, network_options)
+ with gr.Box():
+ ui.title("Model options")
+ args_to_gradio(sd_models_arguments, sd_models_options)
+ with gr.Box():
+ ui.title("Dataset Config options")
+ args_to_gradio(config_arguments, config_options)
+ with gr.Box():
+ ui.title("Dataset options")
+ args_to_gradio(dataset_arguments, dataset_options)
+ with gr.Box():
+ ui.title("Training options")
+ args_to_gradio(training_arguments, training_options)
+ with gr.Box():
+ ui.title("Optimizer options")
+ # The overrides turn optimizer_type / lr_scheduler into fixed choice lists.
+ args_to_gradio(
+ optimizer_arguments,
+ optimizer_options,
+ ui_overrides.OPTIMIZER_OPTIONS,
+ )
+
+ # Both helpers returned callables; they are invoked only after all components exist.
+ init_runner()
+ init_id()
diff --git a/kohya-sd-scripts-webui/scripts/tabs/training/train_textual_inversion.py b/kohya-sd-scripts-webui/scripts/tabs/training/train_textual_inversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..321c2d79063141679df79c75c95c25dbc540e184
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/tabs/training/train_textual_inversion.py
@@ -0,0 +1,86 @@
+import argparse
+
+import gradio as gr
+
+from kohya_ss.library import train_util, config_util
+from scripts import presets, ui, ui_overrides
+from scripts.runner import initialize_runner
+from scripts.utilities import args_to_gradio, load_args_template, options_to_gradio
+
+
+def title():
+ """Return the label used for this tab."""
+ return "Train textual inversion"
+
+
+def create_ui():
+ """Build the Gradio UI for the "Train textual inversion" training tab.
+
+ Options come from two sources: the shared argument groups that kohya_ss
+ train_util/config_util register on scratch ArgumentParser objects, and the
+ script-specific arguments parsed from kohya_ss/train_textual_inversion.py.
+ """
+ # One scratch parser per argument group so each group can be rendered in its own box.
+ sd_models_arguments = argparse.ArgumentParser()
+ dataset_arguments = argparse.ArgumentParser()
+ training_arguments = argparse.ArgumentParser()
+ optimizer_arguments = argparse.ArgumentParser()
+ config_arguments = argparse.ArgumentParser()
+ train_util.add_sd_models_arguments(sd_models_arguments)
+ train_util.add_dataset_arguments(dataset_arguments, True, True, False)
+ train_util.add_training_arguments(training_arguments, True)
+ train_util.add_optimizer_arguments(optimizer_arguments)
+ config_util.add_config_arguments(config_arguments)
+ # Per-group dicts filled with argument name -> component when the UI is built below.
+ sd_models_options = {}
+ dataset_options = {}
+ training_options = {}
+ optimizer_options = {}
+ config_options = {}
+ ti_options = {}
+
+ templates, script_file = load_args_template("train_textual_inversion.py")
+
+ # Lazy on purpose: the option dicts are still empty here and are only populated
+ # further down. On duplicate keys the later dict wins.
+ get_options = lambda: {
+ **sd_models_options,
+ **dataset_options,
+ **training_options,
+ **optimizer_options,
+ **config_options,
+ **ti_options,
+ }
+
+ # Merges argparse's private option-string -> Action mappings with the templates
+ # parsed from the script; the script templates win on duplicate keys.
+ get_templates = lambda: {
+ **sd_models_arguments.__dict__["_option_string_actions"],
+ **dataset_arguments.__dict__["_option_string_actions"],
+ **training_arguments.__dict__["_option_string_actions"],
+ **optimizer_arguments.__dict__["_option_string_actions"],
+ **config_arguments.__dict__["_option_string_actions"],
+ **templates,
+ }
+
+ with gr.Column():
+ init_runner = initialize_runner(script_file, get_templates, get_options)
+ with gr.Box():
+ with gr.Row():
+ init_ui = presets.create_ui(
+ "train_textual_inversion", get_templates, get_options
+ )
+ with gr.Row():
+ with gr.Group():
+ with gr.Box():
+ ui.title("Textual inversion options")
+ options_to_gradio(templates, ti_options)
+ with gr.Box():
+ ui.title("Model options")
+ args_to_gradio(sd_models_arguments, sd_models_options)
+ with gr.Box():
+ ui.title("Dataset Config options")
+ args_to_gradio(config_arguments, config_options)
+ with gr.Box():
+ ui.title("Dataset options")
+ args_to_gradio(dataset_arguments, dataset_options)
+ with gr.Box():
+ ui.title("Training options")
+ args_to_gradio(training_arguments, training_options)
+ with gr.Box():
+ ui.title("Optimizer options")
+ # The overrides turn optimizer_type / lr_scheduler into fixed choice lists.
+ args_to_gradio(
+ optimizer_arguments,
+ optimizer_options,
+ ui_overrides.OPTIMIZER_OPTIONS,
+ )
+
+ # Both helpers returned callables; they are invoked only after all components exist.
+ init_runner()
+ init_ui()
diff --git a/kohya-sd-scripts-webui/scripts/ui.py b/kohya-sd-scripts-webui/scripts/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..eec24142240fe45ba7d2783729bf36177c8884f1
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/ui.py
@@ -0,0 +1,53 @@
+import glob
+import importlib
+import os
+import sys
+
+import gradio as gr
+
+import scripts.shared as shared
+from scripts.shared import ROOT_DIR
+from scripts.utilities import path_to_module
+
+
+def title(txt):
+ """Emit a gr.HTML component containing ``txt`` (used for section titles)."""
+ # NOTE(review): the HTML markup that wrapped {txt} inside this f-string appears
+ # to have been stripped from this copy; restore it from the upstream file.
+ gr.HTML(
+ f'
{txt}
',
+ )
+
+
+def create_ui(css):
+ """Assemble the root gr.Blocks UI and return it.
+
+ Every directory under scripts/tabs that contains *.py files becomes a tab,
+ and every such file is imported and given a nested tab labelled with its
+ title() and filled by its create_ui(). A "terminal" tab is added as well.
+
+ css: stylesheet text forwarded to gr.Blocks.
+ """
+ # Temporarily make kohya_ss/library and the project root importable while
+ # the tab modules are loaded.
+ PATHS = [
+ os.path.join(ROOT_DIR, "kohya_ss", "library"),
+ ROOT_DIR,
+ ]
+ sys.path.extend(PATHS)
+ with gr.Blocks(css=css, analytics_enabled=False) as ui:
+ with gr.Tabs(elem_id="kohya_sd_webui__root"):
+ tabs_dir = os.path.join(ROOT_DIR, "scripts", "tabs")
+ for category in os.listdir(tabs_dir):
+ # NOTE(review): `dir` shadows the builtin of the same name.
+ dir = os.path.join(tabs_dir, category)
+ tabs = glob.glob(os.path.join(dir, "*.py"))
+ sys.path.append(dir)
+ # NOTE(review): this `continue` runs after sys.path.append(dir); whether the
+ # matching sys.path.remove(dir) below is skipped depends on indentation that
+ # is not visible in this copy - confirm against the upstream file.
+ if len(tabs) < 1:
+ continue
+ with gr.TabItem(category):
+ for lib in tabs:
+ try:
+ module_path = path_to_module(lib)
+ module_name = module_path.replace(".", "_")
+
+ module = importlib.import_module(module_path)
+ # Recorded before create_ui() runs; options_to_gradio() reads
+ # shared.current_tab to build element ids.
+ shared.current_tab = module_name
+ shared.loaded_tabs.append(module_name)
+
+ with gr.TabItem(module.title()):
+ module.create_ui()
+ except Exception as e:
+ # A failing tab is reported and skipped so the remaining tabs still load.
+ # NOTE(review): module_path is unbound here if path_to_module() itself raised.
+ print(f"Failed to load {module_path}")
+ print(e)
+ sys.path.remove(dir)
+ with gr.TabItem("terminal"):
+ gr.HTML('')
+ # Drops every sys.path entry equal to one of PATHS, not just the two appended above.
+ sys.path = [x for x in sys.path if x not in PATHS]
+ return ui
diff --git a/kohya-sd-scripts-webui/scripts/ui_overrides.py b/kohya-sd-scripts-webui/scripts/ui_overrides.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d9cf14b1edef370b86b9a46e462aa5868a029b
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/ui_overrides.py
@@ -0,0 +1,26 @@
+# UI overrides for the optimizer arguments, keyed by argument name.
+# Each entry sets the "type" to list and supplies the "choices" to offer for
+# that argument.
+OPTIMIZER_OPTIONS = {
+ "optimizer_type": {
+ "type": list,
+ "choices": [
+ "AdamW",
+ "AdamW8bit",
+ "Lion",
+ "SGDNesterov",
+ "SGDNesterov8bit",
+ "DAdaptation",
+ "AdaFactor",
+ ],
+ },
+ "lr_scheduler": {
+ "type": list,
+ "choices": [
+ "linear",
+ "cosine",
+ "cosine_with_restarts",
+ "polynomial",
+ "constant",
+ "constant_with_warmup",
+ "adafactor",
+ ],
+ },
+}
diff --git a/kohya-sd-scripts-webui/scripts/utilities.py b/kohya-sd-scripts-webui/scripts/utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc76350858e40daceb0e1f623996d9154c9dcbcc
--- /dev/null
+++ b/kohya-sd-scripts-webui/scripts/utilities.py
@@ -0,0 +1,285 @@
+import ast
+import importlib
+import os
+import subprocess
+import sys
+
+import gradio as gr
+
+import scripts.shared as shared
+from scripts.shared import ROOT_DIR
+
+python = sys.executable
+
+
+def path_to_module(filepath):
+    """Convert a file path into a dotted module path relative to ROOT_DIR.
+
+    Only a trailing ".py" extension is removed. Previously every ".py"
+    substring was deleted, which corrupted paths such as "my.pyenv/tool.py".
+
+    filepath: path of a Python file located under ROOT_DIR.
+    Returns the dotted module path, e.g. "scripts.tabs.tools.foo".
+    """
+    relative = os.path.relpath(filepath, ROOT_DIR)
+    if relative.endswith(".py"):
+        relative = relative[: -len(".py")]
+    return relative.replace(os.path.sep, ".")
+
+
+def which(program):
+    """Locate an executable, similar to the shell ``which`` command.
+
+    program: a bare command name, or a path containing a directory part.
+    Returns the path of the executable file, or None when nothing is found.
+    A program given with a directory part is checked as-is; a bare name is
+    searched for in every directory listed in the PATH environment variable.
+    """
+
+    def is_exe(fpath):
+        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
+
+    directory, _ = os.path.split(program)
+    if directory:
+        if is_exe(program):
+            return program
+    else:
+        # Fall back to os.defpath instead of raising KeyError when PATH is unset.
+        for path in os.environ.get("PATH", os.defpath).split(os.pathsep):
+            # Some platforms allow quoted PATH entries.
+            path = path.strip('"')
+            exe_file = os.path.join(path, program)
+            if is_exe(exe_file):
+                return exe_file
+
+    return None
+
+
+def literal_eval(v, module=None):
+    """Resolve the source text of an ``add_argument`` parameter to a Python object.
+
+    v: source text of the value, e.g. "int", "'abc'" or "1e-4".
+    module: optional dotted module name; when given, ``v`` is first looked up
+        as an attribute of that module.
+    Returns the builtin type for "str"/"int"/"float"/"list", the module
+    attribute named ``v`` if one exists, otherwise ``ast.literal_eval(v)``.
+    Raises whatever ``ast.literal_eval`` raises for text that is not a literal.
+    """
+    if v == "str":
+        return str
+    elif v == "int":
+        return int
+    elif v == "float":
+        return float
+    elif v == "list" or v is list:
+        # The original compared against the type object only, so the text
+        # "list" was never recognised; the type itself is still accepted.
+        return list
+
+    if module:
+        try:
+            m = importlib.import_module(module)
+            if hasattr(m, v):
+                return getattr(m, v)
+        except Exception:
+            # Best effort: a module that cannot be imported simply means the
+            # value is evaluated as a plain literal below.
+            pass
+
+    return ast.literal_eval(v)
+
+
+def compile_arg_parser(txt, module_path=None):
+ """Extract ``parser.add_argument(...)`` calls from Python source text.
+
+ The text is scanned character by character rather than executed.
+
+ txt: source text to scan.
+ module_path: forwarded to literal_eval() as the module used to resolve names.
+ Returns a dict keyed by the first value of each call; each entry holds
+ "dest" (that value with "--" removed) plus the call's key=value parameters.
+ """
+ in_parser = False
+ parsers = {}
+ args = []
+ arg = ""
+ in_list = False
+ in_str = None
+
+ # Turns one comma-separated piece of a call into a value, or into a
+ # (name, value) pair when it contains "=".
+ def compile(arg):
+ arg = arg.strip()
+ matches = arg.split("=")
+
+ if len(matches) > 1:
+ k = "".join(matches[:1])
+ # Any further "=" characters in the value are dropped by this join.
+ v = literal_eval("".join(matches[1:]), module_path)
+ return (k, v)
+ else:
+ return literal_eval(arg, module_path)
+
+ for line in txt.split("\n"):
+ # Everything after the first "#" is discarded, including a "#" inside a string literal.
+ line = line.split("#")[0]
+
+ if "parser.add_argument(" in line:
+ in_parser = True
+ line = line.replace("parser.add_argument(", "")
+
+ if not in_parser:
+ continue
+
+ for c in line:
+
+ # A ")" outside a string literal ends the current call.
+ if in_str is None and c == ")":
+ if arg.strip():
+ args.append(compile(arg))
+ in_parser = False
+ # The first piece is the dictionary key; dict(others) requires every
+ # remaining piece to be a (name, value) pair.
+ [dest, *others] = args
+ parsers[dest] = {"dest": dest.replace("--", ""), **dict(others)}
+ arg = ""
+ args = []
+ break
+
+ # in_list is a flag, not a depth counter, so nested brackets are not tracked.
+ if c == "[":
+ in_list = True
+ elif c == "]":
+ in_list = False
+ if c == '"' or c == "'":
+ if in_str is not None and in_str == c:
+ in_str = None
+ elif in_str is None:
+ in_str = c
+
+ # Only a comma outside both a list and a string separates two pieces.
+ if c == "," and not in_list and in_str is None:
+ args.append(compile(arg))
+ arg = ""
+ continue
+
+ arg += c
+
+ if arg.strip():
+ args.append(compile(arg))
+ return parsers
+
+
+def load_args_template(*filename):
+    """Parse the argument definitions of a script inside the kohya_ss directory.
+
+    filename: path components of the script, relative to ROOT_DIR/kohya_ss.
+    Returns ``(templates, filepath)``: the result of compile_arg_parser() run on
+    every line that follows the first line containing "def setup_parser()",
+    and the full path of the script.
+    """
+    filepath = os.path.join(ROOT_DIR, "kohya_ss", *filename)
+    with open(filepath, mode="r", encoding="utf-8_sig") as f:
+        source_lines = f.readlines()
+
+    collected = []
+    capturing = False
+    for source_line in source_lines:
+        if capturing:
+            collected.append(source_line)
+        # The marker line itself is not collected, only what comes after it.
+        if "def setup_parser()" in source_line:
+            capturing = True
+
+    templates = compile_arg_parser("".join(collected), path_to_module(filepath))
+    return templates, filepath
+
+
+def check_key(d, k):
+    """Tell whether ``k`` is present in ``d`` with a value other than None."""
+    if k not in d:
+        return False
+    return d[k] is not None
+
+
+def get_arg_type(d):
+    """Decide which Python type an argument description stands for.
+
+    d: mapping describing one argument (keys such as "choices", "type",
+        "action" and "const" are consulted when present and not None).
+    Returns list when choices are given, the explicit type when one is set,
+    bool for store_true/store_false actions or a boolean const, else str.
+    """
+
+    def present(name):
+        return name in d and d[name] is not None
+
+    if present("choices"):
+        return list
+    if present("type"):
+        return d["type"]
+    if present("action") and d["action"] in ("store_true", "store_false"):
+        return bool
+    if present("const") and isinstance(d["const"], bool):
+        return bool
+    return str
+
+
+def options_to_gradio(options, out, overrides={}):
+    """Create one Gradio input per argument description.
+
+    options: mapping whose values are argument descriptions, either dicts or
+        objects whose ``__dict__`` is used instead.
+    out: dict that receives ``dest -> component`` for every created input.
+    overrides: optional ``dest -> {"type": ..., "choices": ...}`` mapping that
+        takes precedence over the description.
+    The help text of each argument is stored in shared.help_title_map under the
+    element id of its component. The argument whose dest is "help" is skipped.
+    """
+    for entry in options.values():
+        spec = entry.__dict__ if hasattr(entry, "__dict__") else entry
+        key = spec["dest"]
+        if key == "help":
+            continue
+
+        override = overrides.get(key, {})
+        help_text = spec.get("help", "")
+        elem_id = f"kohya_sd_webui__{shared.current_tab.replace('.', '_')}_{key}"
+        arg_type = override["type"] if "type" in override else get_arg_type(spec)
+        has_default = check_key(spec, "default")
+
+        if arg_type == list:
+            raw_choices = (
+                override["choices"] if "choices" in override else spec["choices"]
+            )
+            # None cannot be shown as a choice, so it is displayed as the text "None".
+            choices = ["None" if c is None else c for c in raw_choices]
+            component = gr.Radio(
+                choices=choices,
+                value=spec["default"] if has_default else choices[0],
+                label=key,
+                elem_id=elem_id,
+                interactive=True,
+            )
+        elif arg_type == bool:
+            component = gr.Checkbox(
+                value=spec["default"] if has_default else False,
+                label=key,
+                elem_id=elem_id,
+                interactive=True,
+            )
+        else:
+            component = gr.Textbox(
+                value=spec["default"] if has_default else "",
+                label=key,
+                elem_id=elem_id,
+                interactive=True,
+            ).style()
+
+        shared.help_title_map[elem_id] = help_text
+        out[key] = component
+
+
+def args_to_gradio(args, out, overrides={}):
+    """Render every option registered on an ArgumentParser as a Gradio input.
+
+    args: parser whose private ``_option_string_actions`` mapping is read.
+    out, overrides: forwarded unchanged to options_to_gradio().
+    """
+    registered_actions = vars(args)["_option_string_actions"]
+    options_to_gradio(registered_actions, out, overrides)
+
+
+def gradio_to_args(arguments, options, args, strarg=False):
+ """Convert submitted UI values back into typed argument values.
+
+ arguments: mapping of argument descriptions (dicts, or objects whose
+ __dict__ is used), searched by their "dest" entry.
+ options: mapping dest -> key used to look the value up in ``args``.
+ args: submitted values, indexed as args[options[dest]].
+ strarg: when true, return ``(main, optional)`` where ``main`` lists the
+ non-None values whose argument key does not start with "--" and
+ ``optional`` maps dest -> value for the others; otherwise return a
+ single dict dest -> value.
+ """
+
+ # Returns the (key, description) pair whose "dest" equals key, or (None, None).
+ def find_arg(key):
+ for k, arg in arguments.items():
+ arg = arg.__dict__ if hasattr(arg, "__dict__") else arg
+ if arg["dest"] == key:
+ return k, arg
+ return None, None
+
+ def get_value(key):
+ item = args[options[key]]
+ # NOTE(review): when no description matches, arg is None and the calls below fail.
+ raw_key, arg = find_arg(key)
+ arg_type = get_arg_type(arg)
+ multiple = "nargs" in arg and arg["nargs"] == "*"
+
+ # Converts one raw value; None and the text "None" both become None.
+ def set_type(x):
+ if x is None or x == "None":
+ return None
+ elif arg_type is None:
+ return x
+ elif arg_type == list:
+ return x
+ return arg_type(x)
+
+ # Parsed as (multiple and item is None) or item == "": an empty string
+ # yields None for every argument, not only for multi-value ones.
+ if multiple and item is None or item == "":
+ return raw_key, None
+
+ # Multi-value arguments are entered as one space-separated string.
+ return raw_key, (
+ [set_type(x) for x in item.split(" ")] if multiple else set_type(item)
+ )
+
+ if strarg:
+ main = []
+ optional = {}
+
+ for k in options:
+ key, v = get_value(k)
+ if key.startswith("--"):
+ # The dest name (k), not the matched option string, becomes the key.
+ key = k.replace("--", "")
+ optional[key] = v
+ else:
+ main.append(v)
+
+ main = [x for x in main if x is not None]
+
+ return main, optional
+ else:
+ result = {}
+ for k in options:
+ _, v = get_value(k)
+ result[k] = v
+ return result
+
+
+def make_args(d):
+    """Turn a ``name -> value`` mapping into a list of command-line tokens.
+
+    d: mapping of option name (without leading dashes) to its value.
+    Returns a list of strings:
+      * bool: "--name" when true, "" when false (callers filter empty tokens);
+      * non-empty list: "--name" followed by each non-None item as text;
+      * any other value: "--name" followed by the value as text.
+    None, empty strings and empty lists produce nothing. Numeric zero is
+    emitted; it used to be dropped because it is falsy.
+    """
+    arguments = []
+    for k, v in d.items():
+        flag = f"--{k}"
+        if isinstance(v, bool):
+            arguments.append(flag if v else "")
+        elif isinstance(v, list):
+            if v:
+                # Items are converted to text so that ints/floats are valid
+                # process arguments and a 0 survives later truthiness filters.
+                arguments.extend([flag, *(str(x) for x in v if x is not None)])
+        elif isinstance(v, (int, float)) or v:
+            arguments.extend([flag, f"{v}"])
+    return arguments
+
+
+def run_python(script, templates, options, args):
+    """Start ``script`` in a child Python process with arguments taken from the UI.
+
+    script: path of the script to run.
+    templates, options, args: forwarded to gradio_to_args() to obtain the
+        positional and optional argument values.
+    Returns the subprocess.Popen object; stdout and stderr are merged into one
+    pipe and the working directory is ROOT_DIR/kohya_ss.
+    """
+    main, optional = gradio_to_args(templates, options, args, strarg=True)
+    # Drop only missing/empty tokens and convert the rest to text: a plain
+    # truthiness filter discarded 0, and non-string values broke both the
+    # join below and Popen.
+    cli_args = [
+        str(x) for x in [*main, *make_args(optional)] if x is not None and x != ""
+    ]
+    proc_args = [python, "-u", script, *cli_args]
+    print("Start process: ", " ".join(proc_args))
+
+    ps = subprocess.Popen(
+        proc_args,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        cwd=os.path.join(ROOT_DIR, "kohya_ss"),
+    )
+    return ps
diff --git a/kohya-sd-scripts-webui/style.css b/kohya-sd-scripts-webui/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..4ce5f9cab2ec1ba78a431ff98bfd809a1db40679
--- /dev/null
+++ b/kohya-sd-scripts-webui/style.css
@@ -0,0 +1,8 @@
+/* Terminal output pane: fixed to 80% of the viewport height, scrolling vertically on overflow. */
+#kohya_sd_webui__terminal_outputs {
+ height: 80vh;
+ overflow-y: auto;
+}
+
+/* Hide every button whose id starts with "kohya_sd_webui__" and ends with "_stop_button". */
+button[id^='kohya_sd_webui__'][id$='_stop_button'] {
+ display: none;
+}
\ No newline at end of file
diff --git a/kohya-sd-scripts-webui/sub.py b/kohya-sd-scripts-webui/sub.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5f2fa2e9342de5a99e1308db5feddc8a65f68a3
--- /dev/null
+++ b/kohya-sd-scripts-webui/sub.py
@@ -0,0 +1,8 @@
+from tqdm import tqdm
+from time import sleep
+
+# Dummy long-running job: ten 5-second steps reported through a tqdm progress bar.
+# NOTE(review): presumably used to exercise the web UI's process output handling - confirm.
+progress_bar = tqdm(range(10), smoothing=0, desc="steps")
+for i in range(10):
+ sleep(5)
+ progress_bar.update(1)
+ progress_bar.set_postfix({"log": f"sleeping 5 sec {i}"})
diff --git a/kohya-sd-scripts-webui/update.bat b/kohya-sd-scripts-webui/update.bat
new file mode 100644
index 0000000000000000000000000000000000000000..62427efb04195ca4c87025040a2d5608fc3f15dc
--- /dev/null
+++ b/kohya-sd-scripts-webui/update.bat
@@ -0,0 +1,4 @@
+@echo off
+REM Makes the checkout match origin/main.
+REM WARNING: "git reset --hard" discards local changes to tracked files.
+git fetch --prune
+git reset --hard origin/main
+pause
\ No newline at end of file
diff --git a/kohya-sd-scripts-webui/update.sh b/kohya-sd-scripts-webui/update.sh
new file mode 100644
index 0000000000000000000000000000000000000000..598f59256d651d64c7a22181c2dacdcf20eac83a
--- /dev/null
+++ b/kohya-sd-scripts-webui/update.sh
@@ -0,0 +1,2 @@
+# Makes the checkout match origin/main.
+# WARNING: "git reset --hard" discards local changes to tracked files.
+git fetch --prune
+git reset --hard origin/main
\ No newline at end of file
diff --git a/kohya-sd-scripts-webui/webui.bat b/kohya-sd-scripts-webui/webui.bat
new file mode 100644
index 0000000000000000000000000000000000000000..7f7c7687a1bfb6024cecfc6daec2b80014f09e0d
--- /dev/null
+++ b/kohya-sd-scripts-webui/webui.bat
@@ -0,0 +1,74 @@
+@echo off
+REM Launcher: checks that Python starts, creates/activates a venv, then runs
+REM launch.py (through "accelerate launch" when ACCELERATE is "True" and
+REM accelerate.exe exists in the venv).
+
+if not defined PYTHON (set PYTHON=python)
+if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
+
+set ERROR_REPORTING=FALSE
+
+mkdir tmp 2>NUL
+
+REM Probe that the interpreter can be started at all.
+%PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :start_venv
+echo Couldn't launch python
+goto :show_stdout_stderr
+
+:start_venv
+REM VENV_DIR set to "-" means: do not use a venv.
+if ["%VENV_DIR%"] == ["-"] goto :skip_venv
+
+dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+
+for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i"
+echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%
+%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt
+if %ERRORLEVEL% == 0 goto :activate_venv
+echo Unable to create venv in directory "%VENV_DIR%"
+goto :show_stdout_stderr
+
+:activate_venv
+set PYTHON="%VENV_DIR%\Scripts\Python.exe"
+echo venv %PYTHON%
+if [%ACCELERATE%] == ["True"] goto :accelerate
+goto :launch
+
+:skip_venv
+REM Falls through to the accelerate check below.
+
+:accelerate
+echo "Checking for accelerate"
+set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe"
+if EXIST %ACCELERATE% goto :accelerate_launch
+
+:launch
+%PYTHON% launch.py %*
+pause
+exit /b
+
+:accelerate_launch
+echo "Accelerating"
+REM Forward the command-line arguments, as the plain launch above does.
+%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py %*
+pause
+exit /b
+
+:show_stdout_stderr
+
+echo.
+echo exit code: %errorlevel%
+
+for /f %%i in ("tmp\stdout.txt") do set size=%%~zi
+if %size% equ 0 goto :show_stderr
+echo.
+echo stdout:
+type tmp\stdout.txt
+
+:show_stderr
+for /f %%i in ("tmp\stderr.txt") do set size=%%~zi
+REM Skip to the end when stderr is empty; jumping back to :show_stderr here looped forever.
+if %size% equ 0 goto :endofscript
+echo.
+echo stderr:
+type tmp\stderr.txt
+
+:endofscript
+
+echo.
+echo Launch unsuccessful. Exiting.
+pause
diff --git a/kohya-sd-scripts-webui/webui.sh b/kohya-sd-scripts-webui/webui.sh
new file mode 100644
index 0000000000000000000000000000000000000000..05ae26b3a7394c8d0cd474777fbc7fd994b193dd
--- /dev/null
+++ b/kohya-sd-scripts-webui/webui.sh
@@ -0,0 +1,58 @@
+#!/usr/bin/env bash
+# Launcher: creates/activates a python venv, then runs the launch script,
+# through "accelerate launch" when ACCELERATE is set to "True" and the
+# accelerate command is available.
+
+# python3 executable
+if [[ -z "${python_cmd}" ]]
+then
+ python_cmd="python3"
+fi
+
+# git executable
+if [[ -z "${GIT}" ]]
+then
+ export GIT="git"
+fi
+
+# python3 venv without trailing slash
+if [[ -z "${venv_dir}" ]]
+then
+ venv_dir="venv"
+fi
+
+if [[ -z "${LAUNCH_SCRIPT}" ]]
+then
+ LAUNCH_SCRIPT="launch.py"
+fi
+
+# this script cannot be run as root by default
+can_run_as_root=0
+delimiter="################################################################"
+
+printf "\n%s\n" "${delimiter}"
+printf "Create and activate python venv"
+printf "\n%s\n" "${delimiter}"
+if [[ ! -d "${venv_dir}" ]]
+then
+ "${python_cmd}" -m venv "${venv_dir}"
+ first_launch=1
+fi
+# shellcheck source=/dev/null
+if [[ -f "${venv_dir}"/bin/activate ]]
+then
+ source "${venv_dir}"/bin/activate
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
+ printf "\n%s\n" "${delimiter}"
+ exit 1
+fi
+
+# Compare the value explicitly: the former [ ${ACCELERATE}="True" ] was a
+# single-word test and therefore true for any non-empty ACCELERATE.
+if [[ "${ACCELERATE}" == "True" ]] && [[ -x "$(command -v accelerate)" ]]
+then
+ printf "\n%s\n" "${delimiter}"
+ printf "Accelerating launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
+else
+ printf "\n%s\n" "${delimiter}"
+ printf "Launching launch.py..."
+ printf "\n%s\n" "${delimiter}"
+ exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
+fi
diff --git a/put extensions here.txt b/put extensions here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-civitai-browser-plus/.github/FUNDING.yml b/sd-civitai-browser-plus/.github/FUNDING.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6d0af7885427407541e59b781dc19c2872b6915b
--- /dev/null
+++ b/sd-civitai-browser-plus/.github/FUNDING.yml
@@ -0,0 +1,3 @@
+# These are supported funding model platforms
+
+custom: ["https://www.paypal.me/JeJongen"]
diff --git a/sd-civitai-browser-plus/.github/ISSUE_TEMPLATE/bug_report.yml b/sd-civitai-browser-plus/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d55765a7a1a78d703c7a031830029fab0dfafd15
--- /dev/null
+++ b/sd-civitai-browser-plus/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,60 @@
+name: 🐛 Bug report
+description: If something isn't working as expected, please report a bug here.
+title: "[Bug]: "
+labels: ["bug"]
+
+body:
+ - type: markdown
+ attributes:
+ value: |
+ *Please fill this form with as much information as possible!*
+ - type: textarea
+ id: what-did
+ attributes:
+ label: Describe the bug.
+ description: A clear and concise description of what the bug is.
+ validations:
+ required: true
+ - type: textarea
+ id: steps
+ attributes:
+ label: Steps to reproduce the problem.
+ description: Precise step by step instructions on how to reproduce the bug.
+ value: |
+ 1. Go to ....
+ 2. Press ....
+ 3. ...
+ validations:
+ required: true
+ - type: textarea
+ id: what-should
+ attributes:
+ label: Expected behavior
+ description: A clear and concise description of what you expected to happen.
+ validations:
+ required: true
+ - type: textarea
+ id: sysinfo
+ attributes:
+ label: System info
+ description: Information about your system and the versions that were used.
+ value: |
+ * Extension version:
+ * OS:
+ * SD-WebUI version:
+ * Python:
+ validations:
+ required: true
+ - type: textarea
+ id: logs
+ attributes:
+ label: Console logs
+ description: Please share the complete cmd/terminal logs from the time the error occurred, ensuring you include all associated error messages.
+ render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: misc
+ attributes:
+ label: Additional information
+ description: Please provide any relevant additional info or context.
diff --git a/sd-civitai-browser-plus/.github/ISSUE_TEMPLATE/config.yml b/sd-civitai-browser-plus/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..95d64503d645d008eaac35b1985cfe4c6dbb2a02
--- /dev/null
+++ b/sd-civitai-browser-plus/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,8 @@
+blank_issues_enabled: false
+contact_links:
+ - name: "🌟 Feature Request"
+ url: "https://github.com/BlafKing/sd-civitai-browser-plus/discussions/new?category=ideas"
+ about: "Feature requests are handled in the Discussions tab under 'Ideas'."
+ - name: "❓ Question"
+ url: "https://github.com/BlafKing/sd-civitai-browser-plus/discussions/new?category=q-a"
+ about: "If you have questions, please ask them in the Discussions tab under 'Q&A'."
diff --git a/sd-civitai-browser-plus/.gitignore b/sd-civitai-browser-plus/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d0683b4e1f5027ec96118b222cb330c864c5a0d6
--- /dev/null
+++ b/sd-civitai-browser-plus/.gitignore
@@ -0,0 +1,3 @@
+__pycache__
+.vscode
+running
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/LICENSE b/sd-civitai-browser-plus/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52
--- /dev/null
+++ b/sd-civitai-browser-plus/LICENSE
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/sd-civitai-browser-plus/README.md b/sd-civitai-browser-plus/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1510758abb6714c45fdc495cdef190ecbc3774c3
--- /dev/null
+++ b/sd-civitai-browser-plus/README.md
@@ -0,0 +1,622 @@
+
+![CivitAI Browser-05+](https://github.com/BlafKing/sd-civitai-browser-plus/assets/9644716/95afcc41-56f0-4398-8779-51cb2a9e2f55)
+
+---
+### Extension for [Automatic1111's Stable Diffusion Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+
+
+
Features 🚀
+
Browse all models from CivitAI 🧩
+
+* Explore a wide range of models at your fingertips.
+
+
Check for updates and installed models 🔄
+
+* Easily spot new updates and identify already installed models while browsing.
+* Ability to scan all installed models for available updates.
+
+
Download any Model, any version, and any file 📥
+
+* Get the specific model version and file you need hassle-free.
+* Download queue to avoid waiting for finished downloads.
+
+
Automatically assign tags to installed models 🏷️
+
+* Assign tags by scanning all installed models for automatic use in image generation.
+
+
Quick Model Info Access 📊
+
+* A button for each model card in txt2img and img2img to load it into the extension.
+* A button under each image in model info to send its generation info to txt2img.
+
+
High-speed downloads with Aria2 🚄
+
+* Maximize your bandwidth for lightning-fast downloads.
+
+
Sleek and Intuitive User Interface 🖌️
+
+* Enjoy a clutter-free, user-friendly interface, designed to enhance your experience.
+
+
Actively maintained with feature requests welcome 🛠️
+
+* Feel free to send me your feature requests, and I'll do my best to implement them!
+
+
+
+
+
Known Issues 🐛
+
+
Unable to download / Frozen download:
+
+**If you're experiencing issues with broken or frozen downloads, there are two possible solutions you can try:**
+
+1. **Revert to the old download method**:
+ A solution could be to disable the "Download models using Aria2" feature.
+This will switch back to the old download method, which may resolve the issue.
+
+ ![Revert to old download method](https://github.com/BlafKing/sd-civitai-browser-plus/assets/9644716/982b0ebb-0cac-4053-8060-285533e0e176)
+
+2. **Disable Async DNS for Aria2**:
+ If you're using a DNS manager program like PortMaster, try turning on the "Disable Async DNS for Aria2" option.
+
+ ![Disable Async DNS for Aria2](https://github.com/BlafKing/sd-civitai-browser-plus/assets/9644716/3cf7fab3-0df5-4995-9543-d9824b7931ff)
+
+These settings can be found under the "Settings" tab in Web-UI and then under the "CivitAI Browser+" tile.
+
+
+
+
+
+# How to install 📘
+
+
+
+1. Download the latest version from this site and unpack the .zip
+![2023-09-25 13_06_31](https://github.com/BlafKing/sd-civitai-browser-plus/assets/9644716/12e46c6b-74b5-4ed5-bf55-cb76c5f75c62)
+
+2. Navigate to your extensions folder (Your SD folder/webui/extensions)
+3. Place the unpacked folder inside the extensions folder
+4. Restart SD-WebUI
+
+# Preview 👀
+
+https://github.com/BlafKing/sd-civitai-browser-plus/assets/9644716/44c5c7a0-4854-4043-bfbb-f32fa9df5a74
+
+
+# Star History 🌟
+
+
+
+
+
+# Changelog 📋
+
+
v3.4.0
+
+* Feature: (BETA) Download queue! rearrange download order and remove models from queue
+ - Will likely contain bugs, still not completely finished.
+* Feature: Customizable sub folder insertion options, choose what sub folder options you want!
+* New setting: Toggle per prompt example image buttons
+* New setting: Insert sub folder options
+* Bug fix: Add to queue fixed, now properly gets enabled.
+* Bug fix: Symlinks now get correctly recognized and used.
+* Bug fix: No longer creates accidental sub folder when bulk downloading.
+
+---
+
v3.3.1
+
+* Feature: Ability to send individual parts of image generation data to txt2img.
+* Feature: Added compatibility for [stable-diffusion-webui-forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) fork.
+* New setting: Use local images in the HTML
+ - Does not work in combination with the "Use local HTML file for model info" option!
+* New setting: Store the HTML and api_info in the custom images location
+* Bug fix: New HTML model info now scales with width so it should always fit.
+* Bug fix: Various bug fixes to the "Update model info & tags" function.
+* Bug fix: Auto save all images now correctly uses the custom image path if set.
+* Bug fix: "Save model info" button should no longer return errors.
+* Bug fix: Old download method (non Aria2) should now work again.
+
+---
+
v3.3.0
+
+* Feature: New txt2img and img2img model info overlay on CivitAI button press.
+* Feature: Base Model as sub folder option.
+* Feature: Ability to multi-download to selected folder.
+* Feature: Use the same folder as older versions when updating using multi-download.
+* Feature: txt2img and img2img CivitAI buttons can use local HTML file, toggle in settings.
+* Note: Save images no longer saves .html and API info, save model info does this instead now.
+* New setting: Save API info of model when saving model info.
+* New setting: Automatically save all images after download.
+* New setting: Use local HTML file for model info.
+* Bug fix: better JSON decode, now forces UTF-8
+* Bug fix: Now uses the proper default file when using multi-download
+* Bug fix: Hide early access models fix, now works when published_at does not exist in API.
+* Bug fix: Fix attempt for queue clearing upon download fail.
+
+---
+
v3.2.5
+
+* Bug fix: Removed default API Key since it gets blocked after many downloads.
+ - Because of this it's now required for some downloads to use a personal CivitAI key, this can be set in the the settings tab of SD-WebUI under the CivitAI Browser+ tab.
+* Bug fix: Fixed bug when selecting a model from txt2img/img2img that doesn't exist on CivitAI.
+* Bug fix: Changed model selection to Model ID instead of model name
+ - This previously caused issues when 2 models were named the same.
+* Bug fix: Fixed an issue where the default file was not properly used by default.
+* Bug fix: Fixed some tiles not being selectable due to having "'" in their title
+* Bug fix: Now automatically removes residual Aria2 files.
+
+---
+
v3.2.4
+
+* Bug fix: Fix version detection for non standard SD-WebUI versions.
+* Bug fix: Retry to fetch ModelID if previously not found in update functions.
+* Bug fix: Style fix for when the Lobe theme is used in SD-WebUI
+* Bug fix: Better required packages import error catching.
+* Bug fix: Fixed CivitAI button scaling in txt2img and img2img tabs.
+* Bug fix: Added ability to handle models that have no hashes saved.
+
+---
+
v3.2.3
+
+* Bug fix: Generate hash toggle in update models was inverted (silly mistake, sry bout that)
+* Bug fix: Better error detection if no model IDs were retrieved during update functions.
+* Bug fix: Better error handling if a local model does not exist on CivitAI
+
+---
+
v3.2.2
+
+* Bug fix: Fixed an `api_response` issue in the update model functions
+* Bug fix: Reverted automatically retrieving base models to fix startup issues
+* Bug fix: Better error description if a model no longer exists on CivitAI
+* Bug fix: Primary file is now used as default file.
+* Bug fix: Search after updating models no longer returns errors.
+
+---
+
v3.2.1
+
+* Feature: Extension now automatically retrieves latest base models from CivitAI.
+* Bug fix: Hotfix for functionality with SD.Next
+
+---
+
v3.2.0
+
+* Feature: A toggle for One-Time hash generation for externally downloaded models.
+* Feature: Updated extension settings layout for SD-WebUI 1.7.0 and higher.
+* Bug fix: Set default value of Lora & LoCON combination based on SD-WebUI version.
+* Bug fix: LORA models with embedding files now get placed inside embeddings folder.
+* Bug fix: Better tile count handling to avoid issues with incorrect tile count.
+* Bug fix: Better settings saving/loading to prevent writing issues.
+
+---
+
v3.1.1
+
+* Bug fix: Early Access models now get correctly hidden/detected.
+* Bug fix: Better timeout/offline server detection for options in "Update Models" tab.
+* Bug fix: Better error detection if required packages were not installed/imported.
+* Bug fix: Download button now displays as "Add to queue" during active download.
+
+---
+
v3.1.0
+
+* Feature: Send to txt2img, Send any image in the model info to txt2img.
+* Feature: Added new Base model filters:
+ - SD 1.5 LCM, SDXL 1.0 LCM, SDXL Distilled, SDXL Turbo, SVD, SVD XT
+* Feature: Hide installed models filter toggle.
+* Feature: Better display of permissions and tags in model info.
+* New setting: Append sub folders to custom image path.
+* New setting: Toggle gif/video playback. Disable it if videos are causing high CPU usage.
+* Bug fix: Better handling if hash is not found.
+
+---
+
v3.0
+
+* Feature: Download queue! Ability to add downloads to a queue. (Finally!)
+* Feature: Checkboxes to download multiple models at once.
+ - This will automatically use the first version and first file of the selected model(s).
+ - Will use the default sub folders per content type defined in sub folder settings.
+* Feature: "Select all" button to select all downloadable models at once.
+* Feature: "Open on CivitAI" button when viewing a model's metadata in txt2img or img2img.
+ - Will only display if the model's info has been saved to the .json after v3.0
+* Feature: Ability to rename model filename
+ - Note that it's not recommended to change the filename since some checks rely on it.
+* Bug fix: Fixed display of saved .html files.
+* Bug fix: Removed potential illegal characters from file name/path name.
+* Bug fix: Fixed case sensitive sorting of sub folders.
+
+---
+
v2.1.0
+
+* Feature: "Overwrite any existing previews, tags or descriptions" Toggle in Update tab.
+* Feature: Added content type "All" to model scanning to select all content types.
+
+---
+
v2.0.1
+
+* Bug fix: Folders starting with "." now no longer show sub folders.
+* Bug fix: Added headers to simulate browser request. (May fix issues for users from Russia)
+
+---
+
v2.0
+
+* Feature: New button on each model card in txt2img and img2img to view it in the extension.
+
+ Preview
+
+![ezgif-3-b1f0de4dd2](https://github.com/BlafKing/sd-civitai-browser-plus/assets/9644716/536a693a-c30c-438e-a34f-1aec54e4e7ee)
+
+
+
+* Feature: NSFW toggle now properly impacts search results.
+* Feature: Ability to set [\Model Name] & [\Model Name\Version Name] as default sub folders.
+* New setting: Hide sub folders that start with a '.'
+* Bug fix: Preview HTML is now emptied when loading a new page.
+* Bug fix: Buttons now correctly display when loading new page.
+* Bug fix: Fixed compatibility with SD.Next. (again)
+* Bug fix: Emptied tags, base model, and filename upon loading new page.
+* Bug fix: Filter change detection fixed
+
+---
+
v1.16
+
+* Feature: Ability to download/update model preview images in Update Models tab.
+* Feature: "Update model tags" changed to "Update model info & tags".
+ - The option now saves tags, description and base model version.
+ - This also applies to the browser, saved tags is changed to save model info.
+* Bug fix: Archived models are now hidden since they cannot be used.
+
+---
+
v1.15.2
+
+* New setting: Custom save images location
+* New setting: Default sub folders
+ - Any sub folders you have will be able to be selected as default, per content type.
+ - If a content type doesn't appear, then it means there are no subfolders in that type.
+* Bug fix: Unreleased models caused a crash, now hidden by default since they can't be used.
+
+---
+
v1.15.1
+
+* New setting: Show console logs during update scanning.
+* Bug fix: Scan for update no longer prints incorrect info about outdated models.
+* Bug fix: Removed bad logic which triggered the same function multiple times.
+* Cleanup: Optimized functions and improved the speed of selecting models.
+
+---
+
v1.15
+
+* Feature: Filter option to show favorited models. (requires personal API key)
+* Feature: Back to top button when viewing model details.
+* New setting: Page navigation as header. (keeps page navigation always visible at the top)
+* Bug fix: Aria2 now restarts when UI is reloaded.
+* Bug fix: SHA256 error fixed if .json files don't contain it.
+* Bug fix: Cleaned up javascript code.
+
+---
+
v1.14.7
+
+* New setting: Hide early access models (EA models are only downloadable by supporters)
+* New setting: Personal CivitAI API key (Text field to insert personal API key)
+ - Useful for CivitAI supporters, you can use your own API Key to allow downloading Early Access models
+* Bug fix: Extension now works with `no gradio queue` flag.
+* Bug fix: Auto disable Aria2 on MacOS due to incompatibility.
+* Bug fix: Now properly works on SD.Next again.
+* Bug fix: Download progression and cancelling is no longer broken on old download method.
+* Bug fix: Extension now correctly downloads models where it is required to be logged in.
+* Bug fix: Extension no longer attempts to install already installed requirements.
+
+---
+
v1.14.6
+
+* Bug fix: Removed pre-load of default page, caused issues for some users.
+* Bug fix: Fixed internal model naming, caused issues when model names included '
+* Bug fix: Different host for .svg icons, caused issues with MalwareBytes.
+* Bug fix: Preview saving was broken due to passing the wrong file path.
+
+---
+
v1.14.5
+
+* Feature: Base Model filter now impacts search results.
+* Feature: Ability to input model URL into search bar to find corresponding model.
+* Bug fix: Adetailer models now get placed in the correct folder
+
+---
+
v1.14.4
+
+* Bug fix: Page slider broke the Next Page button when loaded from "Update Models".
+* Bug fix: "Save settings as default" button inserted broken .json data.
+* Bug fix: Triggering "Scan for available updates" twice resulted in an error.
+
+---
+
v1.14.3
+
+* Bug fix: LORA content type was broken when "Treat LoCon as LORA" was turned on.
+
+---
+
v1.14.2
+
+* Feature: Custom page handling when scanning models.
+* Bug fix: Model scan feature now works for large model count (+900)
+* Bug fix: Better broken .json error handling
+
+---
+
v1.14.1
+
+* Bug fix: Gifs did not display properly.
+* Bug fix: Videos are no longer saved as previews since they cannot be used.
+* Bug fix: Filter window was not hidden by default.
+
+---
+
v1.14
+
+* Feature: Redesign of UI.
+* Feature: New dropdown with filter settings.
+* Feature: Button to save current filter settings as default. (requires restart)
+* Feature: Tag box can now be typed in to save custom tags.
+* Feature: Delete function removes any unpacked files.
+
+---
+
v1.13
+
+* Feature: Updated available content types:
+ - Upscaler
+ - MotionModule
+ - Wildcards
+ - Workflows
+ - Other
+* Feature: Videos can now also be displayed on preview cards and in the model info.
+* Feature: Automatically scans upscaler type by looking through model's description.
+* Feature: Automatically identify correct folder for wildcards based on extension.
+* Bug fix: Version ID got saved instead of correct Model ID after download.
+
+---
+
v1.12.5
+
+* Bug fix: [Installed] tag was only assigned to latest installed version.
+* Bug fix: Folder location didn't update when selecting different version/file.
+* Bug fix: Version scanning didn't properly scan sha256 in uppercase.
+
+---
+
v1.12.4
+
+* Feature: You can now refresh by pressing Ctrl+Enter and Alt+Enter.
+* Bug fix: Auto unpack feature was unpacking unintended archives, now only unpacks .zip.
+
+---
+
v1.12.3
+
+* New setting: Option to toggle automatically unpacking .zip models.
+* Bug fix: Error wasn't caught when file path was incorrect.
+
+---
+
v1.12.2
+
+* Feature: Able to download multiple files from each version.
+* Bug fix: Models did not get deleted properly when in nested folders.
+* Bug fix: Wrong sha256 was being saved after downloading.
+* Bug fix: Wrong default folder was used when installed model got selected.
+
+---
+
v1.12.1
+
+* Feature: File deletion now uses both SHA256 and file name to detect correct file.
+* New setting: Option to toggle automatically inserting 2 default sub folders.
+* New setting: Option to toggle installing LoCON's in LORA folder.
+* Bug fix: Default file was incorrect when selecting a model.
+* Bug fix: Next Page caused an error when changing content type.
+
+---
+
v1.12
+
+* Feature: Ability to load all selected installed models into browser in Update Models tab.
+* Feature: Installed/outdated models check is now done using SHA256 + file name.
+* Feature: Ability to select multiple content Types when searching and scanning.
+* Feature: Greatly improved speed of model scanning if model ID is saved in .json
+
+---
+
v1.11.2
+
+* Feature: Redesign of model page by [ManOrMonster](https://github.com/ManOrMonster)
+* Model page changes (https://github.com/BlafKing/sd-civitai-browser-plus/pull/33)
+
+ - Redesigned the look of the model page.
+ - Added link to model page on CivitAI. Click on model name to open.
+ - Added link to uploader/creator page on CivitAI. Click creator name to open.
+ - Added CivitAI avatar display.
+ - Separate description section.
+ - First sample image is marked with data attribute and downloaded as preview image instead of grabbing first in model HTML. This guarantees that the first sample image (not avatar or image in description) is used when downloading the model.
+ - Sample images are marked with data attribute so that only they are downloaded when using "Save Images" (no description images or avatar).
+ - Removed trained tags from info since they are displayed above.
+ - Each sample image has its own section.
+ - Sample images zoom in when clicked, zoom out when clicking anywhere.
+ - Forced width is removed from sample image URLs so that nice big images can be viewed.
+ - Metadata is arranged so that the most commonly used data is at the top, no more searching for prompts.
+ - Extra metadata is in accordion labeled "More details...". This is especially useful to hide insanely large ComfyUI JSON.
+
+
+
+---
+
v1.11.1
+
+* Feature: Error detection during Aria2 downloads.
+* Bug fix: Avoid starting Aria2 RPC multiple times with better port check.
+* Bug fix: Fixed dynamic tile status updates after deleting/downloading.
+---
+
+
v1.11
+
+* Feature: Ability to scan all installed models for available updates.
+* Feature: Model ID and sha256 get saved to .json after scanning or downloading a model.
+* Bug fix: Fixed crash when base model is not found.
+* Bug fix: No longer overwrite sha256 and model ID in existing .json.
+---
+
+
v1.10.1
+
+* Bug fix: Fixed pathing for Unix systems
+* Bug fix: Extra checks to prevent deleting unintentional files.
+* Feature: Models get moved to trash instead of fully deleted.
+---
+
v1.10
+
+* Feature: Update tags for all installed models!
+* Feature: Tabs for Browsing and updating Tags.
+* Feature: Buttons to select which folders to update tags in.
+---
+
v1.9.4
+
+* Feature: Added Civit AI settings tab
+ - New setting: Disable downloading with Aria2. (will use old download method instead)
+ - New setting: Disable using Async DNS. (can fix issues for some users who use DNS managing programs)
+ - New setting: Show Aria2 logs in the CMD.
+ - New setting: Set the amount of connections when downloading a model with Aria2.
+ (The optimal connection count is different per user, try to find the lowest option which still gives you full bandwidth speed)
+---
+
v1.9.3
+
+* Feature: Included Motrix Aria2 version.
+* Feature: Max connections per server set to 64 and split file set to 64.
+* Feature: Aria2 is now shipped with this extension for Linux as well.
+---
+
+
v1.9.2
+
+* Cleanup: Split up script into multiple files for improved oversight/readability.
+* Cleanup: Centered model icons
+---
+
+
v1.9.1
+
+* Bug fix: Added back old download function if aria2 fails.
+---
+
+
v1.9
+
+* Feature: Faster downloads by using Aria2.
+* Feature: More info about current download: Speed, ETA, File Size and % completion.
+---
+
+
v1.8.1
+
+* Feature: Sub Folder list now contains 2 default options: `/{Model name}` & `/{Model name}/{Version name}`
+---
+
+
v1.8
+
+* Feature: Ability to download different file types per version.
+* Feature: NSFW Toggle is now dynamic.
+* Feature: Version list now dynamically updates after download.
+* Cleanup: Rearranged/Resized UI elements.
+* Bug fix: Downloading models now uses file ID instead of names.
+* Bug fix: NSFW Toggle no longer hides images tagged as "Soft".
+* Bug fix: Fixed each model load running twice.
+---
+
+
v1.7.2
+
+* Bug fix: Download button did not get re-enabled properly.
+* Bug fix: Tile status did not get updated properly when download failed.
+---
+
+
v1.7.1
+
+* Feature: Base Model filtering dims tiles instead of hiding.
+* Bug fix: NSFW Blur increases with tile size.
+* Bug fix: Dynamic tile status after installation & deletion now correctly detects other versions.
+---
+
+
v1.7
+
+* Feature: Introduced separate download progress bar, browse while downloading.
+* Feature: no more force refresh after installing, cancelling and deleting.
+* Feature: Added toggle to sort Tiles by date.
+* Feature: Dynamic changing of tile borders after installation & deletion.
+* Removal: 'Auto delete old version' removed since it relied on a reload.
+---
+
+
v1.6
+
+* Bug fix: Page count is now always correctly read when refreshing.
+(You can fill in the page number you'd like to visit and press refresh to go to that page)
+* Feature: 'Filter Base Model' to dynamically hide any unselected Base models.
+(Please note: This does not impact search results, since the CivitAI API does not yet support this)
+---
+
+
v1.5
+
+* Feature: Slider to change tile size.
+* Feature: Download Folder textbox which can be used to define a custom download path.
+* Feature: Sub Folder Dropdown to select any available subfolder(s) as download location.
+* Feature: Display a timed out message instead of an error icon.
+* Bug fix: Nested files can now be detected as installed or outdated.
+* Bug fix: Auto selects corresponding folder of any installed models.
+* Bug fix: Better cancellation logic to prevent downloads from continuing.
+---
+
+
v1.4
+
+* Feature: Download progress bar is now on web page instead of CMD.
+* Feature: Added Cancel and Delete buttons.
+* Feature: Download button will now change according to circumstances:
+ - Cancel button if there's a current download.
+ - Delete button if the selected version is installed.
+* Cleanup: Better margin fixes with theme detection.
+* Bug fix: Delete option now also removes .json files.
+* Bug fix: Buttons are now disabled during download. (except cancel button)
+---
+
+
v1.3.1
+
+* Bug fix: Fixed tag saving bugs/oversights.
+* Bug fix: Trained tags display now do not include the model itself.
+---
+
+
v1.3
+
+* Feature: 'Save Tags' button saves tags to a .json file which gets used in image creation.
+ (If a LORA with saved tags is used it will automatically input all tags into the txt box in image creation)
+* Feature: 'Save tags after download' toggle to automatically save .json tags.
+* Cleanup: Removed "Get model info" button.
+* Cleanup: Removed download link box.
+* Cleanup: Removed "No" from search options.
+* Cleanup: Added border radius to cards.
+* Cleanup: Improved padding based on if Lobe theme is being used.
+---
+
+
v1.2
+
+* Feature: Automatically saves preview image when downloading a model.
+* Feature: Added [installed] text suffix for any versions that are installed in the 'Version' tab.
+* Cleanup: Changed 'Model Filename' from a dropbox to a textbox.
+* Cleanup: Made bottom textboxes non typeable.
+* Cleanup: Disabled bottom buttons when no model is selected.
+* Bug fix: Margin error on the latest tile.
+* Bug fix: Version checking is now case sensitive.
+* Bug fix: Default version in version tab shows installed version.
+---
+
+
v1.1
+
+* Feature: Dropdown box which can filter by time period.
+* Cleanup: 'Content type' changed from buttons to a dropdown box.
+* Bug fix: Fixed tiles not reloading when already pressed.
+---
+
+
v1.0
+
+* Feature: 'Refresh' now reloads the current page unless any options have been changed.
+* Feature: Made the page refresh after a download and made it load during one.
+* Feature: Orange glow for any outdated installed packages.
+* Feature: 'Delete old version after download' option.
+* Feature: Ability to manually fill in a page number to load the corresponding page.
+* Cleanup: Removed new folder option.
+* Cleanup: Made the glow around frames always visible without hovering.
+* Pulled fork from: [SignalFlagZ's Fork](https://github.com/SignalFlagZ/sd-civitai-browser) [v1.1.0](https://github.com/SignalFlagZ/sd-civitai-browser/releases/tag/1.1.0)
diff --git a/sd-civitai-browser-plus/aria2/lin/aria2 b/sd-civitai-browser-plus/aria2/lin/aria2
new file mode 100644
index 0000000000000000000000000000000000000000..b93f84f39353b2784d18c91d1525d8f12f904aa6
--- /dev/null
+++ b/sd-civitai-browser-plus/aria2/lin/aria2
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f50816471fbd5e91c04df6cf6f995c6279d295b70187f313e6d3b04f65769fc
+size 9926088
diff --git a/sd-civitai-browser-plus/aria2/win/aria2.exe b/sd-civitai-browser-plus/aria2/win/aria2.exe
new file mode 100644
index 0000000000000000000000000000000000000000..5485d42b9ba6f54ef5dfa6c71150ea725737d493
--- /dev/null
+++ b/sd-civitai-browser-plus/aria2/win/aria2.exe
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:098a065ab71f639bf7048e790c870756fd6e83de9cc678915bbd07077d473fa2
+size 8136704
diff --git a/sd-civitai-browser-plus/install.py b/sd-civitai-browser-plus/install.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81b40fd150a172e90565fc5d1d9677991c1e14a
--- /dev/null
+++ b/sd-civitai-browser-plus/install.py
@@ -0,0 +1,19 @@
+import launch
+from pathlib import Path
+
+aria2path = Path(__file__).resolve().parents[0] / "aria2"
+
+for item in aria2path.iterdir():
+ if item.is_file():
+ item.unlink()
+
+def install_req(check_name, install_name=None):
+ if not install_name: install_name = check_name
+ if not launch.is_installed(f"{check_name}"):
+ launch.run_pip(f"install {install_name}", "requirements for CivitAI Browser")
+
+install_req("send2trash")  # one call per dependency; the optional second argument is the pip package name when it differs from the name that is checked
+install_req("zip_unicode", "ZipUnicode")
+install_req("bs4", "beautifulsoup4")
+install_req("fake_useragent")
+install_req("packaging")
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/javascript/Sortable.min.js b/sd-civitai-browser-plus/javascript/Sortable.min.js
new file mode 100644
index 0000000000000000000000000000000000000000..bb9953355480571864d048d1b4070237af5027d6
--- /dev/null
+++ b/sd-civitai-browser-plus/javascript/Sortable.min.js
@@ -0,0 +1,2 @@
+/*! Sortable 1.15.2 - MIT | git://github.com/SortableJS/Sortable.git */
+!function(t,e){"object"==typeof exports&&"undefined"!=typeof module?module.exports=e():"function"==typeof define&&define.amd?define(e):(t=t||self).Sortable=e()}(this,function(){"use strict";function e(e,t){var n,o=Object.keys(e);return Object.getOwnPropertySymbols&&(n=Object.getOwnPropertySymbols(e),t&&(n=n.filter(function(t){return Object.getOwnPropertyDescriptor(e,t).enumerable})),o.push.apply(o,n)),o}function I(o){for(var t=1;tt.length)&&(e=t.length);for(var n=0,o=new Array(e);n"===e[0]&&(e=e.substring(1)),t))try{if(t.matches)return t.matches(e);if(t.msMatchesSelector)return t.msMatchesSelector(e);if(t.webkitMatchesSelector)return t.webkitMatchesSelector(e)}catch(t){return}}function P(t,e,n,o){if(t){n=n||document;do{if(null!=e&&(">"!==e[0]||t.parentNode===n)&&p(t,e)||o&&t===n)return t}while(t!==n&&(t=(i=t).host&&i!==document&&i.host.nodeType?i.host:i.parentNode))}var i;return null}var g,m=/\s+/g;function k(t,e,n){var o;t&&e&&(t.classList?t.classList[n?"add":"remove"](e):(o=(" "+t.className+" ").replace(m," ").replace(" "+e+" "," "),t.className=(o+(n?" 
"+e:"")).replace(m," ")))}function R(t,e,n){var o=t&&t.style;if(o){if(void 0===n)return document.defaultView&&document.defaultView.getComputedStyle?n=document.defaultView.getComputedStyle(t,""):t.currentStyle&&(n=t.currentStyle),void 0===e?n:n[e];o[e=!(e in o||-1!==e.indexOf("webkit"))?"-webkit-"+e:e]=n+("string"==typeof n?"":"px")}}function v(t,e){var n="";if("string"==typeof t)n=t;else do{var o=R(t,"transform")}while(o&&"none"!==o&&(n=o+" "+n),!e&&(t=t.parentNode));var i=window.DOMMatrix||window.WebKitCSSMatrix||window.CSSMatrix||window.MSCSSMatrix;return i&&new i(n)}function b(t,e,n){if(t){var o=t.getElementsByTagName(e),i=0,r=o.length;if(n)for(;i=n.left-e&&i<=n.right+e,e=r>=n.top-e&&r<=n.bottom+e;return o&&e?a=t:void 0}}),a);if(e){var n,o={};for(n in t)t.hasOwnProperty(n)&&(o[n]=t[n]);o.target=o.rootEl=e,o.preventDefault=void 0,o.stopPropagation=void 0,e[K]._onDragOver(o)}}var i,r,a}function Bt(t){V&&V.parentNode[K]._isOutsideThisEl(t.target)}function Ft(t,e){if(!t||!t.nodeType||1!==t.nodeType)throw"Sortable: `el` must be an HTMLElement, not ".concat({}.toString.call(t));this.el=t,this.options=e=a({},e),t[K]=this;var n,o,i={group:null,sort:!0,disabled:!1,store:null,handle:null,draggable:/^[uo]l$/i.test(t.nodeName)?">li":">*",swapThreshold:1,invertSwap:!1,invertedSwapThreshold:null,removeCloneOnHide:!0,direction:function(){return Pt(t,this.options)},ghostClass:"sortable-ghost",chosenClass:"sortable-chosen",dragClass:"sortable-drag",ignore:"a, img",filter:null,preventOnFilter:!0,animation:0,easing:null,setData:function(t,e){t.setData("Text",e.textContent)},dropBubble:!1,dragoverBubble:!1,dataIdAttr:"data-id",delay:0,delayOnTouchOnly:!1,touchStartThreshold:(Number.parseInt?Number:window).parseInt(window.devicePixelRatio,10)||1,forceFallback:!1,fallbackClass:"sortable-fallback",fallbackOnBody:!1,fallbackTolerance:0,fallbackOffset:{x:0,y:0},supportPointer:!1!==Ft.supportPointer&&"PointerEvent"in window&&!u,emptyInsertThreshold:5};for(n in 
W.initializePlugins(this,t,i),i)n in e||(e[n]=i[n]);for(o in kt(e),this)"_"===o.charAt(0)&&"function"==typeof this[o]&&(this[o]=this[o].bind(this));this.nativeDraggable=!e.forceFallback&&Nt,this.nativeDraggable&&(this.options.touchStartThreshold=1),e.supportPointer?h(t,"pointerdown",this._onTapStart):(h(t,"mousedown",this._onTapStart),h(t,"touchstart",this._onTapStart)),this.nativeDraggable&&(h(t,"dragover",this),h(t,"dragenter",this)),Dt.push(this.el),e.store&&e.store.get&&this.sort(e.store.get(this)||[]),a(this,x())}function jt(t,e,n,o,i,r,a,l){var s,c,u=t[K],d=u.options.onMove;return!window.CustomEvent||y||w?(s=document.createEvent("Event")).initEvent("move",!0,!0):s=new CustomEvent("move",{bubbles:!0,cancelable:!0}),s.to=e,s.from=t,s.dragged=n,s.draggedRect=o,s.related=i||e,s.relatedRect=r||X(e),s.willInsertAfter=l,s.originalEvent=a,t.dispatchEvent(s),c=d?d.call(u,s,a):c}function Ht(t){t.draggable=!1}function Lt(){Tt=!1}function Kt(t){return setTimeout(t,0)}function Wt(t){return clearTimeout(t)}Ft.prototype={constructor:Ft,_isOutsideThisEl:function(t){this.el.contains(t)||t===this.el||(mt=null)},_getDirection:function(t,e){return"function"==typeof this.options.direction?this.options.direction.call(this,t,e,V):this.options.direction},_onTapStart:function(e){if(e.cancelable){var n=this,o=this.el,t=this.options,i=t.preventOnFilter,r=e.type,a=e.touches&&e.touches[0]||e.pointerType&&"touch"===e.pointerType&&e,l=(a||e).target,s=e.target.shadowRoot&&(e.path&&e.path[0]||e.composedPath&&e.composedPath()[0])||l,c=t.filter;if(!function(t){xt.length=0;var e=t.getElementsByTagName("input"),n=e.length;for(;n--;){var o=e[n];o.checked&&xt.push(o)}}(o),!V&&!(/mousedown|pointerdown/.test(r)&&0!==e.button||t.disabled)&&!s.isContentEditable&&(this.nativeDraggable||!u||!l||"SELECT"!==l.tagName.toUpperCase())&&!((l=P(l,t.draggable,o,!1))&&l.animated||tt===l)){if(ot=j(l),rt=j(l,t.draggable),"function"==typeof c){if(c.call(this,e,l,this))return 
q({sortable:n,rootEl:s,name:"filter",targetEl:l,toEl:o,fromEl:o}),G("filter",n,{evt:e}),void(i&&e.cancelable&&e.preventDefault())}else if(c=c&&c.split(",").some(function(t){if(t=P(s,t.trim(),o,!1))return q({sortable:n,rootEl:t,name:"filter",targetEl:l,fromEl:o,toEl:o}),G("filter",n,{evt:e}),!0}))return void(i&&e.cancelable&&e.preventDefault());t.handle&&!P(s,t.handle,o,!1)||this._prepareDragStart(e,a,l)}}},_prepareDragStart:function(t,e,n){var o,i=this,r=i.el,a=i.options,l=r.ownerDocument;n&&!V&&n.parentNode===r&&(o=X(n),Q=r,Z=(V=n).parentNode,J=V.nextSibling,tt=n,lt=a.group,ct={target:Ft.dragged=V,clientX:(e||t).clientX,clientY:(e||t).clientY},ft=ct.clientX-o.left,pt=ct.clientY-o.top,this._lastX=(e||t).clientX,this._lastY=(e||t).clientY,V.style["will-change"]="all",o=function(){G("delayEnded",i,{evt:t}),Ft.eventCanceled?i._onDrop():(i._disableDelayedDragEvents(),!s&&i.nativeDraggable&&(V.draggable=!0),i._triggerDragStart(t,e),q({sortable:i,name:"choose",originalEvent:t}),k(V,a.chosenClass,!0))},a.ignore.split(",").forEach(function(t){b(V,t.trim(),Ht)}),h(l,"dragover",Yt),h(l,"mousemove",Yt),h(l,"touchmove",Yt),h(l,"mouseup",i._onDrop),h(l,"touchend",i._onDrop),h(l,"touchcancel",i._onDrop),s&&this.nativeDraggable&&(this.options.touchStartThreshold=4,V.draggable=!0),G("delayStart",this,{evt:t}),!a.delay||a.delayOnTouchOnly&&!e||this.nativeDraggable&&(w||y)?o():Ft.eventCanceled?this._onDrop():(h(l,"mouseup",i._disableDelayedDrag),h(l,"touchend",i._disableDelayedDrag),h(l,"touchcancel",i._disableDelayedDrag),h(l,"mousemove",i._delayedDragTouchMoveHandler),h(l,"touchmove",i._delayedDragTouchMoveHandler),a.supportPointer&&h(l,"pointermove",i._delayedDragTouchMoveHandler),i._dragStartTimer=setTimeout(o,a.delay)))},_delayedDragTouchMoveHandler:function(t){t=t.touches?t.touches[0]:t;Math.max(Math.abs(t.clientX-this._lastX),Math.abs(t.clientY-this._lastY))>=Math.floor(this.options.touchStartThreshold/(this.nativeDraggable&&window.devicePixelRatio||1))&&this._disableDelayedDr
ag()},_disableDelayedDrag:function(){V&&Ht(V),clearTimeout(this._dragStartTimer),this._disableDelayedDragEvents()},_disableDelayedDragEvents:function(){var t=this.el.ownerDocument;f(t,"mouseup",this._disableDelayedDrag),f(t,"touchend",this._disableDelayedDrag),f(t,"touchcancel",this._disableDelayedDrag),f(t,"mousemove",this._delayedDragTouchMoveHandler),f(t,"touchmove",this._delayedDragTouchMoveHandler),f(t,"pointermove",this._delayedDragTouchMoveHandler)},_triggerDragStart:function(t,e){e=e||"touch"==t.pointerType&&t,!this.nativeDraggable||e?this.options.supportPointer?h(document,"pointermove",this._onTouchMove):h(document,e?"touchmove":"mousemove",this._onTouchMove):(h(V,"dragend",this),h(Q,"dragstart",this._onDragStart));try{document.selection?Kt(function(){document.selection.empty()}):window.getSelection().removeAllRanges()}catch(t){}},_dragStarted:function(t,e){var n;wt=!1,Q&&V?(G("dragStarted",this,{evt:e}),this.nativeDraggable&&h(document,"dragover",Bt),n=this.options,t||k(V,n.dragClass,!1),k(V,n.ghostClass,!0),Ft.active=this,t&&this._appendGhost(),q({sortable:this,name:"start",originalEvent:e})):this._nulling()},_emulateDragOver:function(){if(ut){this._lastX=ut.clientX,this._lastY=ut.clientY,Rt();for(var t=document.elementFromPoint(ut.clientX,ut.clientY),e=t;t&&t.shadowRoot&&(t=t.shadowRoot.elementFromPoint(ut.clientX,ut.clientY))!==e;)e=t;if(V.parentNode[K]._isOutsideThisEl(t),e)do{if(e[K])if(e[K]._onDragOver({clientX:ut.clientX,clientY:ut.clientY,target:t,rootEl:e})&&!this.options.dragoverBubble)break}while(e=(t=e).parentNode);Xt()}},_onTouchMove:function(t){if(ct){var 
e=this.options,n=e.fallbackTolerance,o=e.fallbackOffset,i=t.touches?t.touches[0]:t,r=$&&v($,!0),a=$&&r&&r.a,l=$&&r&&r.d,e=Mt&&yt&&E(yt),a=(i.clientX-ct.clientX+o.x)/(a||1)+(e?e[0]-Ct[0]:0)/(a||1),l=(i.clientY-ct.clientY+o.y)/(l||1)+(e?e[1]-Ct[1]:0)/(l||1);if(!Ft.active&&!wt){if(n&&Math.max(Math.abs(i.clientX-this._lastX),Math.abs(i.clientY-this._lastY))D.right+10||S.clientY>x.bottom&&S.clientX>x.left:S.clientY>D.bottom+10||S.clientX>x.right&&S.clientY>x.top)||m.animated)){if(m&&(t=n,e=r,C=X(B((_=this).el,0,_.options,!0)),_=L(_.el,_.options,$),e?t.clientX<_.left-10||t.clientY addOrUpdateRule(sheet, selector, rules);
+
+ toggleRule('.civcardnsfw', hideAndBlur ? 'display: block;' : 'display: none;');
+ toggleRule('.civnsfw img', hideAndBlur ? 'filter: none;' : 'filter: blur(10px);');
+
+ const dateSections = document.querySelectorAll('.date-section');
+ dateSections.forEach((section) => {
+ const cards = section.querySelectorAll('.civmodelcard');
+ const nsfwCards = section.querySelectorAll('.civmodelcard.civcardnsfw');
+ section.style.display = !hideAndBlur && cards.length === nsfwCards.length ? 'none' : 'block';
+ });
+
+}
+
+// Updates site with css insertions
+function addOrUpdateRule(styleSheet, selector, newRules) {
+ for (let i = 0; i < styleSheet.cssRules.length; i++) {
+ let rule = styleSheet.cssRules[i];
+ if (rule.selectorText === selector) {
+ rule.style.cssText = newRules;
+ return;
+ }
+ }
+ styleSheet.insertRule(`${selector} { ${newRules} }`, styleSheet.cssRules.length);
+}
+
+// Updates card border
+function updateCard(modelNameWithSuffix) {
+ const lastDotIndex = modelNameWithSuffix.lastIndexOf('.');
+ const modelName = modelNameWithSuffix.slice(0, lastDotIndex);
+ const suffix = modelNameWithSuffix.slice(lastDotIndex + 1);
+ let additionalClassName = '';
+ switch(suffix) {
+ case 'None':
+ additionalClassName = '';
+ break;
+ case 'Old':
+ additionalClassName = 'civmodelcardoutdated';
+ break;
+ case 'New':
+ additionalClassName = 'civmodelcardinstalled';
+ break;
+ default:
+ return;
+ }
+ const parentDiv = document.querySelector('.civmodellist');
+ if (parentDiv) {
+ const cards = parentDiv.querySelectorAll('.civmodelcard');
+ cards.forEach((card) => {
+ const onclickAttr = card.getAttribute('onclick');
+ if (onclickAttr && onclickAttr.includes(`select_model('${modelName}', event)`)) {
+ card.className = `civmodelcard ${additionalClassName}`;
+ }
+ });
+ }
+}
+
+// Enables refresh with alt+enter and ctrl+enter
+document.addEventListener('keydown', function(e) {
+ var handled = false;
+
+ if (e.key !== undefined) {
+ if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
+ } else if (e.keyCode !== undefined) {
+ if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
+ }
+
+ if (handled) {
+ var currentTabContent = get_uiCurrentTabContent();
+ if (currentTabContent && currentTabContent.id === "tab_civitai_interface") {
+
+ var refreshButton = currentTabContent.querySelector('#refreshBtn');
+ if (!refreshButton) {
+ refreshButton = currentTabContent.querySelector('#refreshBtnL');
+ }
+ if (refreshButton) {
+ refreshButton.click();
+ }
+
+ e.preventDefault();
+ }
+ }
+});
+
+// Function for the back to top button
+function BackToTop() {
+ const c = Math.max(document.body.scrollTop, document.documentElement.scrollTop);
+ if (c > 0) {
+ window.requestAnimationFrame(BackToTop);
+ document.body.scrollTop = c - c / 8;
+ document.documentElement.scrollTop = c - c / 8;
+ }
+}
+
+// Function to adjust alignment of Filter Accordion
+function adjustFilterBoxAndButtons() {
+ const element = document.querySelector("#filterBox") || document.querySelector("#filterBoxL");
+ if (!element) return;
+
+ const childDiv = element.querySelector("div:nth-child(3)");
+ if (!childDiv) return;
+
+ const isLargeScreen = window.innerWidth >= 1250;
+ const isMediumScreen = window.innerWidth < 1250 && window.innerWidth > 915;
+ const isNarrowScreen = window.innerWidth < 800;
+ const modelBlocks = document.querySelectorAll("#civitai_preview_html .model-block");
+ const civitInfo = document.querySelector(".civitai-version-info");
+
+ if (modelBlocks) {
+ modelBlocks.forEach(modelBlock => {
+ if (isNarrowScreen) {
+ modelBlock.style.flexWrap = "wrap";
+ modelBlock.style.justifyContent = "center";
+ } else {
+ modelBlock.style.flexWrap = "nowrap";
+ modelBlock.style.justifyContent = "flex-start";
+ }
+ });
+ } if (civitInfo) {
+ if (window.innerWidth < 900) {
+ civitInfo.style.flexWrap = "wrap";
+ } else {
+ civitInfo.style.flexWrap = "nowrap";
+ }
+ }
+
+
+ childDiv.style.marginLeft = isLargeScreen ? "0px" : isMediumScreen ? `${1250 - window.innerWidth}px` : "0px";
+ element.style.justifyContent = isLargeScreen || isMediumScreen ? "center" : "flex-start";
+
+ const pageBtn1 = document.querySelector("#pageBtn1");
+ const pageBtn2 = document.querySelector("#pageBtn2");
+ const pageBox = document.querySelector("#pageBox");
+ const pageBoxMobile = document.querySelector("#pageBoxMobile");
+
+ if (window.innerWidth < 530) {
+ childDiv.style.width = "300px";
+ if (pageBoxMobile) {
+ pageBtn1 && pageBoxMobile.appendChild(pageBtn1);
+ pageBtn2 && pageBoxMobile.appendChild(pageBtn2);
+ pageBoxMobile.style.paddingBottom = "15px";
+ }
+ } else {
+ childDiv.style.width = "400px";
+ if (pageBox) {
+ pageBtn1 && pageBox.insertBefore(pageBtn1, pageBox.firstChild);
+ pageBtn2 && pageBox.appendChild(pageBtn2);
+ pageBoxMobile.style.paddingBottom = "0px";
+ }
+ }
+}
+
+// Re-runs adjustFilterBoxAndButtons (defined above) whenever the window is resized
+window.addEventListener("resize", adjustFilterBoxAndButtons);
+
+// Function to trigger refresh button with extra checks for page slider
+function pressRefresh() {
+ setTimeout(() => {
+ const input = document.querySelector("#pageSlider > div:nth-child(2) > div > input");
+ if (document.activeElement === input) {
+ input.addEventListener('keydown', function(event) {
+ if (event.key === 'Enter' || event.keyCode === 13) {
+ input.blur();
+ }
+ });
+ input.addEventListener('blur', function() {
+ return;
+ });
+
+ return;
+ }
+
+ let button = document.querySelector("#refreshBtn");
+ if (!button) {
+ button = document.querySelector("#refreshBtnL");
+ }
+ if (button) {
+ button.click();
+ } else {
+ console.error("Both buttons with IDs #refreshBtn and #refreshBtnL not found.");
+ }
+ }, 200);
+}
+
+// Update SVG Icons based on dark theme or light theme
+function updateSVGIcons() {
+ const isDark = document.body.classList.contains('dark');
+ const filterIconUrl = isDark ? "https://gistcdn.githack.com/BlafKing/a20124cedafad23d4eecc1367ec22896/raw/04a4dae0771353377747dadf57c91d55bf841bed/filter-light.svg" : "https://gistcdn.githack.com/BlafKing/686c3438f5d0d13e7e47135f25445ef3/raw/46477777faac7209d001829a171462d9a2ff1467/filter-dark.svg";
+ const searchIconUrl = isDark ? "https://gistcdn.githack.com/BlafKing/3f95619089bac3b4fd5470a986e1b3bb/raw/ebaa9cceee3436711eb560a7a65e151f1d651c6a/search-light.svg" : "https://gistcdn.githack.com/BlafKing/57573592d5857e102a4bfde852f62639/raw/aa213e9e82d705651603507e26545eb0ffe60c90/search-dark.svg";
+
+ const element = document.querySelector("#filterBox, #filterBoxL");
+ const childDiv = element?.querySelector("div:nth-child(3)");
+
+ if (childDiv) {
+ childDiv.style.cssText = `box-shadow: ${isDark ? '#ffffff' : '#000000'} 0px 0px 2px 0px; display: none;`;
+ }
+
+ const style = document.createElement('style');
+ style.innerHTML = `
+ #filterBox > div:nth-child(2) > span:nth-child(2)::before,
+ #filterBoxL > div:nth-child(2) > span:nth-child(2)::before {
+ background: url('${filterIconUrl}') no-repeat center center;
+ background-size: contain;
+ }
+ `;
+ document.head.appendChild(style);
+
+ const refreshBtn = document.querySelector("#refreshBtn, #refreshBtnL");
+ const targetSearchElement = refreshBtn?.firstChild || refreshBtnL?.firstChild;
+
+ if (targetSearchElement) {
+ targetSearchElement.src = searchIconUrl;
+ }
+}
+
+// Creates a tooltip if the user wants to filter liked models without a personal API key
+function createTooltip(element, hover_element, insertText) {
+ if (element) {
+ const tooltip = document.createElement('div');
+ tooltip.className = 'browser_tooltip';
+ tooltip.textContent = insertText;
+ tooltip.style.cssText = 'display: none; text-align: center; white-space: pre;';
+
+ hover_element.addEventListener('mouseover', () => {
+ tooltip.style.display = 'block';
+ });
+ hover_element.addEventListener('mouseout', () => {
+ tooltip.style.display = 'none';
+ });
+ element.appendChild(tooltip);
+ }
+}
+
+// Function that closes filter dropdown if clicked outside the dropdown
+function setupClickOutsideListener() {
+ var filterBox = document.getElementById("filterBoxL") || document.getElementById("filterBox");
+ var filterButton = filterBox.getElementsByTagName("div")[1];
+ var dropDown = filterBox.getElementsByTagName("div")[2];
+
+ function clickOutsideHandler(event) {
+ var target = event.target;
+ if (!filterBox.contains(target)) {
+ if (!dropDown.contains(target)) {
+ if (filterButton.className.endsWith("open")) {
+ filterButton.click();
+ }
+ }
+ }
+ }
+ document.addEventListener("click", clickOutsideHandler);
+}
+
+// Create hyperlink in settings to CivitAI account settings
+function createLink(infoElement) {
+
+ const existingText = "(You can create your own API key in your CivitAI account settings, this required for some downloads, Requires UI reload)";
+ const linkText = "CivitAI account settings";
+
+ const [textBefore, textAfter] = existingText.split(linkText);
+
+ const link = document.createElement('a');
+ link.textContent = linkText;
+ link.href = 'https://civitai.com/user/account';
+ link.target = '_blank';
+
+ while (infoElement.firstChild) infoElement.removeChild(infoElement.firstChild);
+
+ infoElement.appendChild(document.createTextNode(textBefore));
+ infoElement.appendChild(link);
+ infoElement.appendChild(document.createTextNode(textAfter));
+}
+
+// Function to update the visibility of backToTopDiv based on the intersection with civitaiDiv
+function updateBackToTopVisibility(entries) {
+ var backToTopDiv = document.getElementById('backToTopContainer');
+ var civitaiDiv = document.getElementById('civitai_preview_html');
+
+ if (civitaiDiv.clientHeight > 0 && entries[0].isIntersecting && window.scrollY !== 0) {
+ backToTopDiv.style.visibility = 'visible';
+ } else {
+ backToTopDiv.style.visibility = 'hidden';
+ }
+}
+
+// Options for the Intersection Observer:
+// - root: null -> intersections are measured against the browser viewport
+// - rootMargin -> the -60px bottom margin shrinks the observed area by 60px at the bottom
+// - threshold: 0 -> the callback fires as soon as any part of the target crosses the boundary
+var options = {
+ root: null,
+ rootMargin: '0px 0px -60px 0px',
+ threshold: 0
+};
+
+// Create an Intersection Observer instance that reports to updateBackToTopVisibility
+const observer = new IntersectionObserver(updateBackToTopVisibility, options);
+
+function handleCivitaiDivChanges() {
+ var civitaiDiv = document.getElementById('civitai_preview_html');
+ observer.unobserve(civitaiDiv);
+ observer.observe(civitaiDiv);
+}
+
+document.addEventListener("scroll", handleCivitaiDivChanges)
+
+// Create the accordion dropdown inside the settings tab
+// containerDiv: element that receives the accordion; nothing is created when it is null/undefined
+// subfolders: list of items placed inside the collapsible area; nothing is created when it is empty
+// name: label placed at the start of the toggle button's HTML
+function createAccordion(containerDiv, subfolders, name) {
+ if (containerDiv == null || subfolders.length == 0) {
+ return;
+ }
+ var accordionContainer = document.createElement('div');
+ accordionContainer.id = 'settings-accordion';
+ var toggleButton = document.createElement('button');
+ toggleButton.id = 'accordionToggle';
+ // NOTE(review): the markup appended to `name` looks incomplete in this view; the click
+ // handler below rotates toggleButton.lastChild, so the button's last child is expected
+ // to be an element holding the arrow - confirm against the original file.
+ toggleButton.innerHTML = name + '
▼
';
+ // accordionDiv is declared with `var` further down; it is hoisted and has been assigned
+ // by the time a click can reach this handler.
+ toggleButton.onclick = function () {
+ // Flip the collapsible area between hidden and shown
+ accordionDiv.style.display = (accordionDiv.style.display === 'none') ? 'block' : 'none';
+ // Rotate the button's last child by 90deg while hidden, and reset the rotation while shown
+ toggleButton.lastChild.style.transform = accordionDiv.style.display === 'none' ? 'rotate(90deg)' : 'rotate(0)';
+ };
+
+ accordionContainer.appendChild(toggleButton);
+ var accordionDiv = document.createElement('div');
+ accordionDiv.classList.add('accordion');
+ accordionDiv.append(...subfolders);
+ // The accordion starts collapsed
+ accordionDiv.style.display = 'none';
+ accordionContainer.appendChild(accordionDiv);
+ containerDiv.appendChild(accordionContainer);
+}
+
+// Adds a button to the cards in txt2img and img2img
+// Runs on every document click. When the clicked element's text is one of the extra-network
+// tab names, or the Lobe theme is detected, it polls every 100ms until `.card` elements exist
+// and then inserts a `goto-civitbrowser` button at the start of each card's button row.
+// Clicking that button calls modelInfoPopUp for the card's model.
+function createCardButtons(event) {
+ const clickedElement = event.target;
+ const validButtonNames = ['Textual Inversion', 'Hypernetworks', 'Checkpoints', 'Lora'];
+ const validParentIds = ['txt2img_textual_inversion_cards_html', 'txt2img_hypernetworks_cards_html', 'txt2img_checkpoints_cards_html', 'txt2img_lora_cards_html'];
+
+ // True when the clicked element's trimmed text equals one of the tab names above
+ const hasMatchingButtonName = clickedElement && clickedElement.innerText && validButtonNames.some(buttonName =>
+ clickedElement.innerText.trim() === buttonName
+ );
+
+ // The Lobe theme is detected through a link to its GitHub releases page inside a .layoutkit-flexbox element
+ const flexboxDivs = document.querySelectorAll('.layoutkit-flexbox');
+ let isLobeTheme = false;
+ flexboxDivs.forEach(div => {
+ const anchorElements = div.querySelectorAll('a');
+ const hasGitHubLink = Array.from(anchorElements).some(anchor => anchor.href === 'https://github.com/lobehub/sd-webui-lobe-theme/releases');
+ if (hasGitHubLink) {
+ isLobeTheme = true;
+ }
+ });
+
+ if (hasMatchingButtonName || isLobeTheme) {
+ // Poll until at least one card is present; the interval is cleared on the first hit
+ const checkForCardDivs = setInterval(() => {
+ const cardDivs = document.querySelectorAll('.card');
+
+ if (cardDivs.length > 0) {
+ clearInterval(checkForCardDivs);
+
+ // Card text scale setting as a percentage; it drives the icon's viewBox offset and CSS scale
+ const cardScale = document.querySelector('#setting_extra_networks_card_text_scale > div > div > input').valueAsNumber * 100;
+ const viewBoxHeight = (cardScale < 100) ? (100 - cardScale) * 2 : -(cardScale - 100) * 2;
+
+ cardDivs.forEach(cardDiv => {
+ const buttonRow = cardDiv.querySelector('.button-row');
+ const actions = cardDiv.querySelector('.actions');
+ // Cards without an .actions element are skipped
+ if (!actions) {
+ return;
+ }
+ const nameSpan = actions.querySelector('.name');
+ let modelName = nameSpan.textContent.trim();
+ let currentElement = cardDiv.parentElement;
+ let content_type = null;
+
+ // Walk up the ancestors until one of the known card-container ids is found;
+ // content_type stays null when none matches
+ while (currentElement) {
+ const parentId = currentElement.id;
+ if (validParentIds.includes(parentId)) {
+ content_type = parentId;
+ break;
+ }
+ currentElement = currentElement.parentElement;
+ }
+
+ // Skip cards that already received the button on an earlier run
+ const existingDiv = buttonRow.querySelector('.goto-civitbrowser.card-button');
+ if (existingDiv) {
+ return;
+ }
+
+ const metaDataButton = buttonRow.querySelector('.metadata-button.card-button');
+
+ // The model path is taken from the copy-path button's clipboard text, when that button exists
+ const copyPathButton = buttonRow.querySelector('.copy-path-button.card-button');
+ let modelPath = "";
+ if (copyPathButton) {
+ modelPath = copyPathButton.getAttribute('data-clipboard-text');
+ }
+
+ const newDiv = document.createElement('div');
+ newDiv.classList.add('goto-civitbrowser', 'card-button');
+ newDiv.addEventListener('click', function (event) {
+ // Keep the click from reaching the card itself
+ event.stopPropagation();
+ modelInfoPopUp(modelName, content_type, modelPath);
+ });
+
+ const svgIcon = document.createElementNS("http://www.w3.org/2000/svg", "svg");
+ if (isLobeTheme) {
+ svgIcon.setAttribute('width', '25');
+ svgIcon.setAttribute('height', '25');
+ } else {
+ // Default theme: restyle the metadata button (when present) and use a larger icon
+ if (metaDataButton) {
+ metaDataButton.style.paddingTop = '5px';
+ metaDataButton.style.width = '42px';
+ metaDataButton.style.fontSize = '230%';
+ }
+ svgIcon.setAttribute('width', '40');
+ svgIcon.setAttribute('height', '40');
+ newDiv.setAttribute('style', 'width: 42px !important;');
+ }
+ svgIcon.setAttribute('viewBox', `75 ${viewBoxHeight} 500 500`);
+ svgIcon.setAttribute('fill', 'white');
+ svgIcon.setAttribute('style', `scale: ${cardScale}%;`);
+
+ // NOTE(review): the SVG markup assigned here is blank in this view - confirm the
+ // icon's path elements are present in the original file.
+ svgIcon.innerHTML = `
+
+
+ `;
+
+ newDiv.appendChild(svgIcon);
+ buttonRow.insertBefore(newDiv, buttonRow.firstChild);
+ });
+ }
+ }, 100);
+ }
+}
+// Every click anywhere in the document re-runs the check above
+document.addEventListener('click', createCardButtons);
+
+function modelInfoPopUp(modelName, content_type, modelPath) {
+ select_model(modelName, null, true, content_type, modelPath);
+
+ // Create the overlay
+ var overlay = document.createElement('div');
+ overlay.classList.add('civitai-overlay');
+ overlay.style.position = 'fixed';
+ overlay.style.top = '0';
+ overlay.style.left = '0';
+ overlay.style.width = '100%';
+ overlay.style.height = '100%';
+ overlay.style.backgroundColor = 'rgba(20, 20, 20, 0.95)';
+ overlay.style.zIndex = '1001';
+ overlay.style.overflowY = 'auto';
+
+ // Create the close button
+ var closeButton = document.createElement('div');
+ closeButton.classList.add('civitai-overlay-close');
+ closeButton.textContent = '×';
+ closeButton.style.zIndex = '1011';
+ closeButton.style.position = 'fixed';
+ closeButton.style.right = '22px';
+ closeButton.style.top = '0';
+ closeButton.style.cursor = 'pointer';
+ closeButton.style.color = 'white';
+ closeButton.style.fontSize = '32pt';
+ closeButton.addEventListener('click', hidePopup);
+ document.addEventListener('keydown', handleKeyPress);
+
+ // Create the pop-up window
+ var inner = document.createElement('div');
+ inner.classList.add('civitai-overlay-inner');
+ inner.style.position = 'absolute';
+ inner.style.top = '50%';
+ inner.style.left = '50%';
+ inner.style.width = 'auto';
+ inner.style.transform = 'translate(-50%, -50%)';
+ inner.style.background = 'var(--body-background-fill)';
+ inner.style.padding = '2em';
+ inner.style.borderRadius = 'var(--block-radius)';
+ inner.style.borderStyle = 'solid';
+ inner.style.borderWidth = 'var(--block-border-width)';
+ inner.style.borderColor = 'var(--block-border-color)';
+ inner.style.zIndex = '1001';
+
+ // Placeholder model content until model is loaded by other function
+ var modelInfo = document.createElement('div');
+ modelInfo.classList.add('civitai-overlay-text');
+ modelInfo.textContent = 'Loading model info, please wait!';
+ modelInfo.style.fontSize = '24px';
+ modelInfo.style.color = 'white';
+ modelInfo.style.fontFamily = 'var(--font)';
+
+ document.body.style.overflow = 'hidden';
+ document.body.appendChild(overlay);
+ overlay.appendChild(closeButton);
+ overlay.appendChild(inner);
+ inner.appendChild(modelInfo);
+
+ overlay.addEventListener('click', function (event) {
+ if (event.target === overlay) {
+ hidePopup();
+ }
+ });
+
+ setDynamicWidth(inner);
+
+ // Update width on window resize
+ window.addEventListener('resize', function() {
+ setDynamicWidth(inner);
+ });
+}
+
+function setDynamicWidth(inner) {
+ var windowWidth = window.innerWidth;
+ var dynamicWidth = Math.min(Math.max(windowWidth - 150, 350), 900);
+ inner.style.width = dynamicWidth + 'px';
+}
+
+// Function to hide the popup
+function hidePopup() {
+ var overlay = document.querySelector('.civitai-overlay');
+ if (overlay) {
+ document.body.removeChild(overlay);
+ document.body.style.overflow = 'auto';
+ window.removeEventListener('resize', setDynamicWidth);
+ }
+}
+
+// Function to handle key presses
+function handleKeyPress(event) {
+ if (event.key === 'Escape') {
+ hidePopup();
+ }
+}
+
+function inputHTMLPreviewContent(html_input) {
+ var inner = document.querySelector('.civitai-overlay-inner')
+ let startIndex = html_input.indexOf("'value': '");
+ if (startIndex !== -1) {
+ startIndex += "'value': '".length;
+ const endIndex = html_input.indexOf("', 'type': None,", startIndex);
+ if (endIndex !== -1) {
+ let extractedText = html_input.substring(startIndex, endIndex);
+ var modelIdNotFound = extractedText.includes(">Model ID not found. The");
+
+ extractedText = extractedText.replace(/\\n\s* 0) {
+ return;
+ }
+ const genButton = gradioApp().querySelector('#txt2img_extra_tabs > div > button')
+ let input = element.querySelector('dd').textContent;
+ let inf;
+ if (input.endsWith(',')) {
+ inf = input + ' ';
+ } else {
+ inf = input + ', ';
+ }
+ let is_positive = false
+ let is_negative = false
+ switch(type) {
+ case 'prompt':
+ is_positive = true
+ break;
+ case 'negativePrompt':
+ inf = 'Negative prompt: ' + inf;
+ is_negative = true
+ break;
+ case 'seed':
+ inf = 'Seed: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ case 'Size':
+ inf = 'Size: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ case 'Model':
+ inf = 'Model: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ case 'clipSkip':
+ inf = 'Clip skip: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ case 'sampler':
+ inf = 'Sampler: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ case 'steps':
+ inf = 'Steps: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ case 'cfgScale':
+ inf = 'CFG scale: ' + inf;
+ inf = inf + inf + inf;
+ break;
+ }
+ const prompt = gradioApp().querySelector('#txt2img_prompt textarea');
+ const neg_prompt = gradioApp().querySelector('#txt2img_neg_prompt textarea');
+ const cfg_scale = gradioApp().querySelector('#txt2img_cfg_scale > div:nth-child(2) > div > input');
+ let final = '';
+ let cfg = 'CFG scale: ' + cfg_scale.value + ", "
+ let prompt_addon = cfg + cfg + cfg
+ if (is_positive) {
+ final = inf + "\nNegative prompt: " + neg_prompt.value + "\n" + prompt_addon;
+ } else if (is_negative) {
+ final = prompt.value + "\n" + inf + "\n" + prompt_addon;
+ } else {
+ final = prompt.value + "\nNegative prompt: " + neg_prompt.value + "\n" + inf;
+ }
+ genInfo_to_txt2img(final, false)
+ hidePopup();
+ sendClick(genButton);
+}
+
+// Creates a list of the selected models
+var selectedModels = [];
+var selectedTypes = [];
+function multi_model_select(modelName, modelType, isChecked) {
+ if (arguments.length === 0) {
+ selectedModels = [];
+ selectedTypes = [];
+ return;
+ }
+ if (isChecked) {
+ if (!selectedModels.includes(modelName)) {
+ selectedModels.push(modelName);
+ }
+ selectedTypes.push(modelType)
+ } else {
+ var modelIndex = selectedModels.indexOf(modelName);
+ if (modelIndex > -1) {
+ selectedModels.splice(modelIndex, 1);
+ }
+ var typesIndex = selectedTypes.indexOf(modelType);
+ if (typesIndex > -1) {
+ selectedTypes.splice(typesIndex, 1);
+ }
+ }
+ const selected_model_list = gradioApp().querySelector('#selected_model_list textarea');
+ selected_model_list.value = JSON.stringify(selectedModels);
+
+ const selected_type_list = gradioApp().querySelector('#selected_type_list textarea');
+ selected_type_list.value = JSON.stringify(selectedTypes);
+
+ updateInput(selected_model_list);
+ updateInput(selected_type_list);
+}
+
+// Metadata button click detector
+document.addEventListener('click', function(event) {
+ var target = event.target;
+ if (target.classList.contains('edit-button') && target.classList.contains('card-button')) {
+ var parentDiv = target.parentElement;
+ var actionsDiv = parentDiv.nextElementSibling;
+ if (actionsDiv && actionsDiv.classList.contains('actions')) {
+ var nameSpan = actionsDiv.querySelector('.name');
+ if (nameSpan) {
+ var nameValue = nameSpan.textContent;
+ onEditButtonCardClick(nameValue);
+ }
+ }
+ }
+}, true);
+
+// CivitAI Link Button Creation
+// Polls every 100ms until the popup's title (.extra-network-name inside .global-popup-inner)
+// equals `nameValue` (both trimmed), then adds, updates or removes the `.open-in-civitai`
+// element depending on the popup's Description field, and stops polling.
+function onEditButtonCardClick(nameValue) {
+ var checkInterval = setInterval(function() {
+ var globalPopupInner = document.querySelector('.global-popup-inner');
+ var titleElement = globalPopupInner.querySelector('.extra-network-name');
+ if (titleElement.textContent.trim() === nameValue.trim()) {
+ // The element right after the span labelled "Description" is treated as the description textarea
+ var descriptionSpan = Array.from(globalPopupInner.querySelectorAll('span')).find(span => span.textContent.trim() === "Description");
+ if (descriptionSpan) {
+ var descriptionTextarea = descriptionSpan.nextElementSibling;
+ if (descriptionTextarea.value.startsWith('Model URL:')) {
+ // The model URL is the first double-quoted substring of the description
+ var matches = descriptionTextarea.value.match(/"([^"]+)"/);
+ if (matches && matches[1]) {
+ var modelUrl = matches[1];
+
+ // The link container is appended to the element that follows the textarea's fourth ancestor
+ var grandParentDiv = descriptionTextarea.parentElement.parentElement.parentElement.parentElement;
+ var imageDiv = grandParentDiv.nextElementSibling
+ var openInCivitaiDiv = document.querySelector('.open-in-civitai');
+ if (!openInCivitaiDiv) {
+ openInCivitaiDiv = document.createElement('div');
+ openInCivitaiDiv.classList.add('open-in-civitai');
+ imageDiv.appendChild(openInCivitaiDiv);
+ }
+ // NOTE(review): modelUrl is extracted above but not referenced by this assignment
+ // as shown here - confirm the link markup is intact in the original file.
+ openInCivitaiDiv.innerHTML = 'Open on CivitAI';
+ }
+ else {
+ // No quoted URL in the description: remove a previously created link element
+ var openInCivitaiDiv = document.querySelector('.open-in-civitai');
+ if (openInCivitaiDiv) {
+ openInCivitaiDiv.remove();
+ }
+ }
+ } else {
+ // Description does not start with 'Model URL:': remove a previously created link element
+ var openInCivitaiDiv = document.querySelector('.open-in-civitai');
+ if (openInCivitaiDiv) {
+ openInCivitaiDiv.remove();
+ }
+ }
+ }
+ // Polling stops once the title matched, whether or not a Description span was found
+ clearInterval(checkInterval);
+ }
+ }, 100);
+}
+
+/**
+ * Dispatches a synthetic mouse "click" (bubbling, cancelable, view = window) on the target.
+ *
+ * @param {EventTarget} location - Element that receives the click event.
+ */
+function sendClick(location) {
+    const syntheticClick = new MouseEvent('click', { view: window, bubbles: true, cancelable: true });
+    location.dispatchEvent(syntheticClick);
+}
+
+// Cancellation flags: raised by the two functions below, read and reset by the polling
+// callback inside setDownloadProgressBar.
+let currentDlCancelled = false;
+let allDlCancelled = false;
+
+/** Marks the download currently in progress as cancelled. */
+function cancelCurrentDl() {
+    currentDlCancelled = true;
+}
+
+/** Marks every download as cancelled. */
+function cancelAllDl() {
+    allDlCancelled = true;
+}
+
+/**
+ * Makes the #queue_list element sortable. After every drop the moved item's "dl_id"
+ * attribute and new index are written to #arrange_dl_id as "<dl_id>.<newIndex>", and the
+ * current HTML of #civitai_dl_list is mirrored into #queue_html_input.
+ */
+function setSortable() {
+    const handleDrop = (dropEvent) => {
+        const listHtml = document.querySelector('#civitai_dl_list.prose').innerHTML;
+        const htmlField = gradioApp().querySelector('#queue_html_input textarea');
+        const arrangeField = gradioApp().querySelector('#arrange_dl_id textarea');
+
+        arrangeField.value = dropEvent.item.getAttribute('dl_id') + "." + dropEvent.newIndex;
+        updateInput(arrangeField);
+
+        htmlField.value = listHtml;
+        updateInput(htmlField);
+    };
+
+    new Sortable(document.getElementById('queue_list'), { onEnd: handleDrop });
+}
+
+/**
+ * Requests cancellation of the queued download by writing a random, zero-padded
+ * three-digit number ("000"-"999") into the #html_cancel_input textarea and notifying
+ * Gradio through updateInput.
+ */
+function cancelQueueDl() {
+    const cancelBtn = gradioApp().querySelector('#html_cancel_input textarea');
+    const randomNumber = Math.floor(Math.random() * 1000);
+    const paddedNumber = String(randomNumber).padStart(3, '0');
+    // Presumably the random value only exists to make the field content change - TODO confirm.
+    cancelBtn.value = paddedNumber;
+    updateInput(cancelBtn);
+}
+
+/**
+ * Mirrors the progress shown in #DownloadProgress onto the first ".civitai_dl_item" of
+ * #civitai_dl_list. The item is moved into ".civitai_nonqueue_list", its action button is
+ * switched to "Cancel", and a 500 ms poll updates its progress bar until the download
+ * reaches 100 %, is cancelled (currentDlCancelled / allDlCancelled) or the progress text
+ * reports an error. Whenever an item reaches a final state, the list HTML is written back to
+ * #queue_html_input.
+ *
+ * While no progress bar with a width is available, the function re-schedules itself
+ * every 500 ms.
+ */
+function setDownloadProgressBar() {
+    const gradio_html = gradioApp().querySelector('#queue_html_input textarea');
+    const initialContainer = document.querySelector('#DownloadProgress');
+    // Optional chaining: the container itself may not be rendered yet, which is treated
+    // like a missing progress bar (retry) instead of a TypeError.
+    const initialProgress = initialContainer?.querySelector('.progress-bar');
+    if (!initialProgress || !initialProgress.style.width) {
+        setTimeout(setDownloadProgressBar, 500);
+        return;
+    }
+
+    const dlList = document.getElementById('civitai_dl_list');
+    const nonQueue = dlList.querySelector('.civitai_nonqueue_list');
+    const activeItem = dlList.querySelector('.civitai_dl_item');
+    const cancelBtn = activeItem.querySelector('.dl_action_btn > span');
+    cancelBtn.innerText = "Cancel";
+    cancelBtn.setAttribute('onclick', 'cancelQueueDl()');
+    const dlId = activeItem.getAttribute('dl_id');
+    const selector = '.civitai_dl_item[dl_id="' + parseInt(dlId) + '"]';
+
+    nonQueue.appendChild(activeItem);
+
+    // Puts one list item into a final state: "Remove" button bound to the item's own
+    // dl_id, final CSS class, and final progress-bar label/width.
+    const finalizeItem = (item, cssClass, label, width) => {
+        const actionBtn = item.querySelector('.dl_action_btn > span');
+        actionBtn.innerText = "Remove";
+        actionBtn.setAttribute('onclick', 'removeDlItem(' + parseInt(item.getAttribute('dl_id')) + ', this)');
+        item.className = cssClass;
+        const bar = item.querySelector('.dl_progress_bar');
+        bar.textContent = label;
+        bar.style.width = width;
+    };
+
+    // Writes the current list markup back to the hidden Gradio textbox.
+    const syncQueueHtml = () => {
+        gradio_html.value = document.querySelector('#civitai_dl_list.prose').innerHTML;
+        updateInput(gradio_html);
+    };
+
+    const interval = setInterval(() => {
+        const container = document.querySelector('#DownloadProgress');
+        const progress = container.querySelector('.progress-bar');
+        const levelEl = container.querySelector('.progress-level-inner');
+        if (!levelEl) {
+            return;
+        }
+        const dlText = levelEl.innerText;
+        const percentage = parseFloat(progress.style.width);
+
+        const dlItem = dlList.querySelector(selector);
+        const dlProgressBar = dlItem.querySelector('.dl_progress_bar');
+        dlProgressBar.textContent = percentage.toFixed(1) + '%';
+        dlProgressBar.style.width = percentage + '%';
+
+        if (percentage >= 100) {
+            clearInterval(interval);
+            finalizeItem(dlItem, 'civitai_dl_item_completed', 'Completed', '100%');
+            syncQueueHtml();
+            return;
+        }
+
+        if (currentDlCancelled) {
+            clearInterval(interval);
+            currentDlCancelled = false;
+            finalizeItem(dlItem, 'civitai_dl_item_failed', 'Cancelled', "0%");
+            syncQueueHtml();
+            return;
+        }
+
+        if (allDlCancelled) {
+            clearInterval(interval);
+            allDlCancelled = false;
+            // Every remaining item gets its own "Remove" button and dl_id. The previous code
+            // re-used the active item and its id for each iteration, leaving the queued
+            // items' buttons untouched.
+            dlList.querySelectorAll('.civitai_dl_item').forEach((item) => {
+                finalizeItem(item, 'civitai_dl_item_failed', 'Cancelled', "0%");
+                nonQueue.appendChild(item);
+            });
+            syncQueueHtml();
+            return;
+        }
+
+        const downloadFailed = dlText.includes('Encountered an error during download of')
+            || dlText.includes('not found on CivitAI servers')
+            || dlText.includes('requires a personal CivitAI API to be downloaded');
+        if (downloadFailed) {
+            clearInterval(interval);
+            finalizeItem(dlItem, 'civitai_dl_item_failed', 'Failed', "0%");
+            syncQueueHtml();
+        }
+    }, 500);
+}
+
+/**
+ * Removes a download entry from the list and reports the removal to Gradio.
+ *
+ * @param {number} dl_id - Id written to the #remove_dl_id textarea.
+ * @param {Element} element - Element inside the entry; its grandparent node is removed.
+ */
+function removeDlItem(dl_id, element) {
+    const htmlField = gradioApp().querySelector('#queue_html_input textarea');
+    const removeField = gradioApp().querySelector('#remove_dl_id textarea');
+
+    const entryNode = element.parentNode.parentNode;
+    entryNode.parentNode.removeChild(entryNode);
+
+    removeField.value = dl_id;
+    updateInput(removeField);
+
+    // Mirror the list markup (now without the removed entry) into the hidden textbox.
+    htmlField.value = document.querySelector('#civitai_dl_list.prose').innerHTML;
+    updateInput(htmlField);
+}
+
+/**
+ * Select-all toggle for the ".model-checkbox" elements.
+ * When the boxes are uniformly checked or uniformly unchecked, every box is clicked
+ * (which flips all of them); otherwise only the unchecked boxes are clicked.
+ */
+function selectAllModels() {
+    const boxes = [...document.querySelectorAll('.model-checkbox')];
+    const checkedCount = boxes.filter(box => box.checked).length;
+    const isUniform = checkedCount === 0 || checkedCount === boxes.length;
+
+    // The set of boxes to click is fixed before the first click is sent.
+    const toClick = isUniform ? boxes : boxes.filter(box => !box.checked);
+    for (const box of toClick) {
+        sendClick(box);
+    }
+}
+
+/**
+ * Deselects all models: after a one-second delay, clicks every ".model-checkbox" that is
+ * checked at that moment.
+ */
+function deselectAllModels() {
+    const uncheckAll = () => {
+        const checkedBoxes = [...document.querySelectorAll('.model-checkbox')].filter(box => box.checked);
+        for (const box of checkedBoxes) {
+            sendClick(box);
+        }
+    };
+    setTimeout(uncheckAll, 1000);
+}
+
+/**
+ * Sends an image URL to Python to pull its generation info.
+ * Writes "<3-digit random number>.<image_url>" into #civitai_text2img_input, notifies
+ * Gradio, hides the popup and clicks the first button under #txt2img_extra_tabs > div.
+ *
+ * @param {string} image_url - URL of the image.
+ */
+function sendImgUrl(image_url) {
+    const genButton = gradioApp().querySelector('#txt2img_extra_tabs > div > button');
+    const urlField = gradioApp().querySelector('#civitai_text2img_input textarea');
+    const prefix = String(Math.floor(Math.random() * 1000)).padStart(3, '0');
+
+    urlField.value = prefix + "." + image_url;
+    updateInput(urlField);
+
+    hidePopup();
+    sendClick(genButton);
+}
+
+/**
+ * Sends generation info to the txt2img tab: fills #txt2img_prompt with it and then fires a
+ * click on #paste. Does nothing when `genInfo` is falsy.
+ *
+ * @param {string} genInfo - Generation info text.
+ * @param {boolean} [do_slice=true] - When true, the first five characters are dropped.
+ */
+function genInfo_to_txt2img(genInfo, do_slice=true) {
+    const promptBox = gradioApp().querySelector('#txt2img_prompt textarea');
+    const pasteBtn = gradioApp().querySelector('#paste');
+    if (!genInfo) {
+        return;
+    }
+
+    let promptText = genInfo;
+    if (do_slice) {
+        promptText = genInfo.slice(5);
+    }
+
+    promptBox.value = promptText;
+    promptBox.dispatchEvent(new Event('input', { bubbles: true }));
+    pasteBtn.dispatchEvent(new Event('click', { bubbles: true }));
+}
+
+/**
+ * Hides or shows the installed model cards (".column.civmodellist > .civmodelcardinstalled").
+ *
+ * @param {boolean} toggleValue - Truthy sets display to 'none', falsy to 'block'.
+ */
+function hideInstalled(toggleValue) {
+    const displayMode = toggleValue ? 'none' : 'block';
+    for (const card of document.querySelectorAll('.column.civmodellist > .civmodelcardinstalled')) {
+        card.style.display = displayMode;
+    }
+}
+
+// Runs all one-time setup functions once the page is fully loaded.
+// Called once per second by the interval at the bottom of this file. It returns early until
+// #setting_custom_api_key contains an ".info" element, and stops the interval as soon as it does.
+function onPageLoad() {
+ const divElement = document.getElementById('setting_custom_api_key');
+ let civitaiDiv = document.getElementById('civitai_preview_html');
+ let queue_list = document.querySelector("#queue_list");
+ const infoElement = divElement?.querySelector('.info');
+ // Settings UI not rendered yet: try again on the next interval tick.
+ if (!infoElement) {
+ return;
+ }
+
+ // Page is ready: stop polling so the setup below runs only once.
+ clearInterval(intervalID);
+ updateSVGIcons();
+
+ // Group the "...subfolder" settings into an accordion, using whichever settings container exists.
+ let subfolderDiv = document.querySelector("#settings_civitai_browser_plus > div > div");
+ let downloadDiv = document.querySelector("#settings_civitai_browser_download > div > div");
+ if (subfolderDiv || downloadDiv) {
+ let div = subfolderDiv || downloadDiv;
+ let subfolders = div.querySelectorAll("[id$='subfolder']");
+ createAccordion(div, subfolders, "Default sub folders");
+ }
+
+ // Inside the #settings-accordion element, group the "...upscale_subfolder" settings again.
+ let upscalerDiv = document.querySelector("#settings_civitai_browser_plus > div > div > #settings-accordion > div");
+ let downloadDivSub = document.querySelector("#settings_civitai_browser_download > div > div > #settings-accordion > div");
+ if (upscalerDiv || downloadDivSub) {
+ let div = upscalerDiv || downloadDivSub;
+ let upscalers = div.querySelectorAll("[id$='upscale_subfolder']");
+ createAccordion(div, upscalers, "Upscalers");
+ }
+
+ // Group the settings whose id starts with "setting_insert_sub".
+ let settingsDiv = document.querySelector("#settings_civitai_browser > div > div");
+ if (subfolderDiv || settingsDiv) {
+ let div = subfolderDiv || settingsDiv;
+ let subfolders = div.querySelectorAll("[id^='setting_insert_sub']");
+ createAccordion(div, subfolders, "Insert sub folder options");
+ }
+
+
+ // Tooltip on the toggle4L / toggle4 element (whichever exists).
+ let toggle4L = document.getElementById('toggle4L');
+ let toggle4 = document.getElementById('toggle4');
+ if (toggle4L || toggle4) {
+ let like_toggle = toggle4L || toggle4;
+ let insertText = 'Requires an API Key\nConfigurable in CivitAI settings tab';
+ createTooltip(like_toggle, like_toggle, insertText);
+ }
+
+ // Tooltip for the skip-hash toggle; its label is passed as the second argument.
+ let hash_toggle_hover = document.querySelector('#skip_hash_toggle > label');
+ let hash_toggle = document.querySelector('#skip_hash_toggle');
+ if (hash_toggle) {
+ let insertText = 'This option generates unique hashes for models that were not downloaded with this extension.\nA hash is required for any of the options below to work, a model with no hash will be skipped.\nInitial hash generation is a one-time process per file.';
+ createTooltip(hash_toggle, hash_toggle_hover, insertText);
+ }
+
+ // NOTE(review): civitaiDiv and queue_list are not null-checked before being observed - confirm
+ // both elements always exist once the ".info" element is present.
+ observer.observe(civitaiDiv);
+ queueObserver.observe(queue_list, queueObserverOptions);
+ adjustFilterBoxAndButtons();
+ setupClickOutsideListener();
+ createLink(infoElement);
+ updateBackToTopVisibility([{isIntersecting: false}]);
+}
+
+// Calls onPageLoad every second; onPageLoad clears this interval once the page is ready.
+let intervalID = setInterval(onPageLoad, 1000);
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/scripts/civitai_api.py b/sd-civitai-browser-plus/scripts/civitai_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..4736155acddec5815ec6a50a3b4bf18fcb4f47a7
--- /dev/null
+++ b/sd-civitai-browser-plus/scripts/civitai_api.py
@@ -0,0 +1,1222 @@
+import requests
+import json
+import gradio as gr
+import urllib.request
+import urllib.parse
+import urllib.error
+import os
+import re
+import datetime
+import platform
+from PIL import Image
+from io import BytesIO
+from collections import defaultdict
+from modules.images import read_info_from_image
+from modules.shared import cmd_opts, opts
+from modules.paths import models_path, extensions_dir, data_path
+from html import escape
+from scripts.civitai_global import print
+import scripts.civitai_global as gl
+import scripts.civitai_download as _download
+try:
+ from fake_useragent import UserAgent
+except ImportError:
+ print("Python module 'fake_useragent' has not been imported correctly, please try to restart or install it manually.")
+
+gl.init()
+
+def contenttype_folder(content_type, desc=None, fromCheck=False, custom_folder=None):
+    """Return the folder used for models of ``content_type``.
+
+    Args:
+        content_type: Model type name (e.g. "Checkpoint", "LORA", "Upscaler"), or
+            "modelFolder" for the models root itself.
+        desc: Optional model description. It is upper-cased and searched for keywords to
+            pick the folder for the "Upscaler" and "Other" types.
+        fromCheck: When true, "LoCon" always resolves to the "LyCORIS" folder, even if the
+            ``use_LORA`` option is enabled.
+        custom_folder: Optional base folder. When given it replaces ``models_path`` and
+            ``data_path`` and disables every command-line directory override.
+
+    Returns:
+        The folder path, or ``None`` when ``content_type`` is not recognised.
+    """
+    use_LORA = getattr(opts, "use_LORA", False)
+    desc = desc.upper() if desc else "PLACEHOLDER"
+    main_models = custom_folder if custom_folder else models_path
+    main_data = custom_folder if custom_folder else data_path
+
+    def cmd_override(option_name):
+        # Command-line directory overrides only apply without a custom folder. getattr()
+        # keeps this working when the running WebUI does not define the option at all
+        # (direct attribute access would raise AttributeError).
+        if custom_folder:
+            return None
+        return getattr(cmd_opts, option_name, None)
+
+    def lora_folder():
+        return cmd_override("lora_dir") or os.path.join(main_models, "Lora")
+
+    # (keyword looked up in the description, cmd_opts option, default sub folder);
+    # the first keyword found wins, ESRGAN is the fallback.
+    upscaler_folders = (
+        ("SWINIR", "swinir_models_path", "SwinIR"),
+        ("REALESRGAN", "realesrgan_models_path", "RealESRGAN"),
+        ("GFPGAN", "gfpgan_models_path", "GFPGAN"),
+        ("BSRGAN", "bsrgan_models_path", "BSRGAN"),
+    )
+
+    folder = None
+    if content_type == "modelFolder":
+        folder = os.path.join(main_models)
+
+    elif content_type == "Checkpoint":
+        folder = cmd_override("ckpt_dir") or os.path.join(main_models, "Stable-diffusion")
+
+    elif content_type == "Hypernetwork":
+        folder = cmd_override("hypernetwork_dir") or os.path.join(main_models, "hypernetworks")
+
+    elif content_type == "TextualInversion":
+        folder = cmd_override("embeddings_dir") or os.path.join(main_data, "embeddings")
+
+    elif content_type == "AestheticGradient":
+        if custom_folder:
+            folder = os.path.join(custom_folder, "aesthetic_embeddings")
+        else:
+            folder = os.path.join(extensions_dir, "stable-diffusion-webui-aesthetic-gradients", "aesthetic_embeddings")
+
+    elif content_type == "LORA":
+        folder = lora_folder()
+
+    elif content_type == "LoCon":
+        if use_LORA and not fromCheck:
+            folder = lora_folder()
+        else:
+            folder = os.path.join(main_models, "LyCORIS")
+
+    elif content_type == "VAE":
+        folder = cmd_override("vae_dir") or os.path.join(main_models, "VAE")
+
+    elif content_type == "Controlnet":
+        folder = os.path.join(main_models, "ControlNet")
+
+    elif content_type == "Poses":
+        folder = os.path.join(main_models, "Poses")
+
+    elif content_type == "Upscaler":
+        option_name, sub_folder = "esrgan_models_path", "ESRGAN"
+        for keyword, keyword_option, keyword_folder in upscaler_folders:
+            if keyword in desc:
+                option_name, sub_folder = keyword_option, keyword_folder
+                break
+        folder = cmd_override(option_name) or os.path.join(main_models, sub_folder)
+
+    elif content_type == "MotionModule":
+        folder = os.path.join(extensions_dir, "sd-webui-animatediff", "model")
+
+    elif content_type == "Workflows":
+        folder = os.path.join(main_models, "Workflows")
+
+    elif content_type == "Other":
+        if "ADETAILER" in desc:
+            folder = os.path.join(main_models, "adetailer")
+        else:
+            folder = os.path.join(main_models, "Other")
+
+    elif content_type == "Wildcards":
+        folder = os.path.join(extensions_dir, "UnivAICharGen", "wildcards")
+        # Fall back to the sd-dynamic-prompts wildcard folder when the first one is absent.
+        if not os.path.exists(folder):
+            folder = os.path.join(extensions_dir, "sd-dynamic-prompts", "wildcards")
+
+    return folder
+
+def api_to_data(content_type, sort_type, period_type, use_search_term, current_page, base_filter, only_liked, tile_count, search_term=None, nsfw=None, timeOut=None, isNext=None, inputs_changed=None):
+    """Request one page of models and return the result of ``request_civit_api``.
+
+    The URL for ``https://civitai.com/api/v1/models`` is built from the given filters.
+    When ``gl.file_scan`` is set, the URL stored in ``gl.url_list_with_numbers`` for the
+    page is requested instead and paging metadata is patched into the response.
+
+    Args:
+        content_type: Iterable of model types, each added as a ``types`` parameter.
+        sort_type / period_type: Sent as ``sort`` / ``period`` (spaces removed from period).
+        use_search_term: Search mode ("User name", "Tag", anything else = free text);
+            the string "None" disables searching.
+        current_page: Requested page; ``0``, ``None`` and ``""`` mean page 1.
+        base_filter: Iterable of base models, each added as ``baseModels``.
+        only_liked: Adds ``favorites=true`` when truthy.
+        tile_count: Page size (``limit``).
+        search_term: Search text; a civitai.com model URL is turned into an id lookup.
+        nsfw: ``False`` adds ``nsfw=false``; other values add nothing.
+        timeOut / isNext: With ``timeOut`` set, the page after (``isNext``) or before
+            ``current_page`` is requested.
+        inputs_changed: When truthy, resets ``gl.file_scan`` and requests page 1.
+
+    Returns:
+        Whatever ``request_civit_api`` returns for the final URL.
+    """
+    if current_page in [0, None, ""]:
+        current_page = 1
+    if inputs_changed:
+        gl.file_scan = False
+        api_url = f"https://civitai.com/api/v1/models?limit={tile_count}&page=1"
+    else:
+        api_url = f"https://civitai.com/api/v1/models?limit={tile_count}&page={current_page}"
+
+    if timeOut:
+        if isNext:
+            target_page = int(current_page) + 1
+        else:
+            # Clamp at page 1: going back from the first page previously left the page
+            # number unassigned and raised UnboundLocalError.
+            target_page = max(int(current_page) - 1, 1)
+        api_url = f"https://civitai.com/api/v1/models?limit={tile_count}&page={target_page}"
+
+    if period_type:
+        period_type = period_type.replace(" ", "")
+    query = {'sort': sort_type, 'period': period_type}
+    query_str = urllib.parse.urlencode(query, quote_via=urllib.parse.quote)
+
+    if content_type:
+        query_str += "".join(f"&types={model_type}" for model_type in content_type)
+
+    if use_search_term != "None" and search_term:
+        search_term = search_term.replace("\\", "\\\\")
+        # A civitai.com link is only treated as an id lookup when it really contains
+        # "models/<id>"; otherwise it is searched like any other term instead of failing
+        # on a missing regex match.
+        id_match = re.search(r'models/(\d+)', search_term) if "civitai.com" in search_term else None
+        if id_match:
+            # An id lookup replaces every filter collected so far.
+            query_str = f"&ids={urllib.parse.quote(id_match.group(1))}"
+        elif use_search_term == "User name":
+            query_str += f"&username={urllib.parse.quote(search_term)}"
+        elif use_search_term == "Tag":
+            query_str += f"&tag={urllib.parse.quote(search_term)}"
+        else:
+            query_str += f"&query={urllib.parse.quote(search_term)}"
+
+    if base_filter:
+        for base in base_filter:
+            query_str += f"&baseModels={urllib.parse.quote(base)}"
+
+    if only_liked:
+        query_str += "&favorites=true"
+
+    # Only an explicit False disables NSFW results; None leaves the parameter out.
+    if nsfw == False:
+        query_str += "&nsfw=false"
+
+    full_url = f"{api_url}&{query_str}"
+
+    if gl.file_scan:
+        highest_number = max(gl.url_list_with_numbers.keys())
+        full_url = gl.url_list_with_numbers.get(int(current_page))
+        nextPage = int(current_page) + 1
+        prevPage = int(current_page) - 1
+        data = request_civit_api(full_url)
+        # Callers compare the result against the string "timeout", so the response is not
+        # always a dict; only patch paging metadata into a real response.
+        metadata = data.get("metadata") if isinstance(data, dict) else None
+        if isinstance(metadata, dict):
+            metadata["currentPage"] = current_page
+            metadata["totalPages"] = highest_number
+            if not nextPage > highest_number:
+                metadata["nextPage"] = gl.url_list_with_numbers.get(nextPage)
+            if not prevPage == 0:
+                metadata["prevPage"] = gl.url_list_with_numbers.get(prevPage)
+    else:
+        data = request_civit_api(full_url)
+
+    return data
+
+def model_list_html(json_data):
+ """Build the HTML for the model tiles from an API response.
+
+ Replaces json_data['items'] with a filtered copy (see below) and returns the assembled HTML string.
+ """
+ # NOTE(review): several HTML string literals in this function appear empty or split across
+ # lines in this copy of the patch - verify them against the upstream file before editing.
+ video_playback = getattr(opts, "video_playback", True)
+ playback = ""
+ if video_playback: playback = "autoplay loop"
+
+ hide_early_access = getattr(opts, "hide_early_access", True)
+ filtered_items = []
+ current_time = datetime.datetime.utcnow()
+
+ # Keep only versions that have files; with hide_early_access enabled, versions whose
+ # early-access window (publishedAt + earlyAccessTimeFrame days) has not passed are skipped.
+ for item in json_data['items']:
+ versions_to_keep = []
+
+ for version in item['modelVersions']:
+ if not version['files']:
+ continue
+ if hide_early_access:
+ early_access_days = version['earlyAccessTimeFrame']
+ if early_access_days != 0:
+ published_at_str = version.get('publishedAt')
+ if published_at_str is not None:
+ published_at = datetime.datetime.strptime(version['publishedAt'], "%Y-%m-%dT%H:%M:%S.%fZ")
+ adjusted_date = published_at + datetime.timedelta(days=early_access_days)
+ if not current_time > adjusted_date or not published_at_str:
+ continue
+ versions_to_keep.append(version)
+
+ # Models left without any version are dropped entirely.
+ if versions_to_keep:
+ item['modelVersions'] = versions_to_keep
+ filtered_items.append(item)
+
+ json_data['items'] = filtered_items
+
+ HTML = '
'
+ sorted_models = {}
+ existing_files = set()
+ existing_files_sha256 = set()
+ model_folders = set()
+
+ # Collect the storage folder of every listed model.
+ for item in json_data['items']:
+ model_folder = os.path.join(contenttype_folder(item['type'], item['description']))
+ model_folders.add(model_folder)
+
+ # Record every file name found in those folders, plus the upper-cased 'sha256' value of
+ # each .json file that contains a dictionary.
+ for folder in model_folders:
+ for root, dirs, files in os.walk(folder, followlinks=True):
+ for file in files:
+ existing_files.add(file)
+ if file.endswith('.json'):
+ json_path = os.path.join(root, file)
+ with open(json_path, 'r', encoding="utf-8") as f:
+ try:
+ json_file = json.load(f)
+ if isinstance(json_file, dict):
+ sha256 = json_file.get('sha256')
+ if sha256:
+ existing_files_sha256.add(sha256.upper())
+ else:
+ print(f"Invalid JSON data in {json_path}. Expected a dictionary.")
+ except Exception as e:
+ print(f"Error decoding JSON in {json_path}: {e}")
+
+ # Build one card per model.
+ for item in json_data['items']:
+ model_id = item.get('id')
+ model_name = item.get('name')
+ nsfw = ""
+ installstatus = ""
+ baseModel = ""
+ try:
+ if 'baseModel' in item['modelVersions'][0]:
+ baseModel = item['modelVersions'][0]['baseModel']
+ except:
+ baseModel = "Not Found"
+
+ # NOTE(review): this handler assigns baseModel, not date; date is only assigned when
+ # 'updatedAt' is present, yet it is read below when gl.sortNewest is set - confirm intent.
+ try:
+ if 'updatedAt' in item['modelVersions'][0]:
+ date = item['modelVersions'][0]['updatedAt'].split('T')[0]
+ except:
+ baseModel = "Not Found"
+
+ if gl.sortNewest:
+ if date not in sorted_models:
+ sorted_models[date] = []
+
+ # Preview media is taken from the first image of the first version.
+ if any(item['modelVersions']):
+ if len(item['modelVersions'][0]['images']) > 0:
+ if item["modelVersions"][0]["images"][0]['nsfw'] not in ["None", "Soft"]:
+ nsfw = "civcardnsfw"
+ media_type = item["modelVersions"][0]["images"][0]["type"]
+ image = item["modelVersions"][0]["images"][0]["url"]
+ if media_type == "video":
+ image = image.replace("width=", "transcode=true,width=")
+ imgtag = f''
+ else:
+ imgtag = f''
+ else:
+ imgtag = f''
+
+ installstatus = None
+
+ # A file counts as present when its name or its SHA256 matches something found on disk.
+ # A match on the first listed version marks the card "installed", any other version "outdated".
+ for version in reversed(item['modelVersions']):
+ for file in version.get('files', []):
+ file_name = file['name']
+ file_sha256 = file.get('hashes', {}).get('SHA256', "").upper()
+
+ name_match = file_name in existing_files
+ sha256_match = file_sha256 in existing_files_sha256
+ if name_match or sha256_match:
+ if version == item['modelVersions'][0]:
+ installstatus = "civmodelcardinstalled"
+ else:
+ installstatus = "civmodelcardoutdated"
+ model_name_js = model_name.replace("'", "\\'")
+ model_string = escape(f"{model_name_js} ({model_id})")
+ model_card = f''
+
+ # With gl.sortNewest the cards are grouped by date first, otherwise appended directly.
+ if gl.sortNewest:
+ sorted_models[date].append(model_card)
+ else:
+ HTML += model_card
+
+ # Emit the date groups, newest date first.
+ if gl.sortNewest:
+ for date, cards in sorted(sorted_models.items(), reverse=True):
+ HTML += f'
{date}
'
+ HTML += '
'
+ for card in cards:
+ HTML += card
+ HTML += '
'
+
+ HTML += '
'
+ return HTML
+
+def update_prev_page(content_type, sort_type, period_type, use_search_term, search_term, current_page, base_filter, only_liked, nsfw, tile_count):
+    """Load the previous page by delegating to ``update_next_page`` with ``isNext=False``."""
+    page_args = (
+        content_type, sort_type, period_type, use_search_term, search_term,
+        current_page, base_filter, only_liked, nsfw, tile_count,
+    )
+    return update_next_page(*page_args, isNext=False)
+
+def update_next_page(content_type, sort_type, period_type, use_search_term, search_term, current_page, base_filter, only_liked, nsfw, tile_count, isNext=True):
+ """Load the next (isNext=True) or previous (isNext=False) page of results.
+
+ Updates gl.json_data and returns a tuple of Gradio component updates (model list, version
+ list, HTML tiles, paging buttons, page slider and the reset detail widgets).
+ """
+ use_LORA = getattr(opts, "use_LORA", False)
+
+ # With the use_LORA option, the combined 'LORA & LoCon' entry is replaced by both types.
+ if content_type:
+ if use_LORA and 'LORA & LoCon' in content_type:
+ content_type.remove('LORA & LoCon')
+ if 'LORA' not in content_type:
+ content_type.append('LORA')
+ if 'LoCon' not in content_type:
+ content_type.append('LoCon')
+
+ # No usable previous response: let update_model_list request the neighbouring page.
+ # NOTE(review): tile_count is not passed in this call, so update_model_list uses its default (None) - confirm.
+ if gl.json_data is None or gl.json_data == "timeout":
+ timeOut = True
+ return_values = update_model_list(content_type, sort_type, period_type, use_search_term, search_term, current_page, base_filter, only_liked, nsfw, timeOut=timeOut, isNext=isNext)
+ timeOut = False
+
+ return return_values
+
+ # If any filter differs from the previous request, start a fresh listing instead of paging.
+ current_inputs = (content_type, sort_type, period_type, use_search_term, search_term, tile_count, base_filter, nsfw)
+ if current_inputs != gl.previous_inputs and gl.previous_inputs != None:
+ inputs_changed = True
+ else:
+ inputs_changed = False
+
+ if inputs_changed:
+ return_values = update_model_list(content_type, sort_type, period_type, use_search_term, search_term, current_page, base_filter, only_liked, nsfw, tile_count)
+ return return_values
+
+ # Regular listing: follow the nextPage/prevPage URL stored in the last response's metadata.
+ if not gl.file_scan:
+ if isNext:
+ if gl.json_data['metadata']['nextPage'] is not None:
+ gl.json_data = request_civit_api(gl.json_data['metadata']['nextPage'])
+ else:
+ gl.json_data = None
+ else:
+ if gl.json_data['metadata']['prevPage'] is not None:
+ gl.json_data = request_civit_api(gl.json_data['metadata']['prevPage'])
+ else:
+ gl.json_data = None
+ else:
+ # File-scan listing: page URLs come from gl.url_list_with_numbers, and the paging
+ # metadata is rebuilt on the new response.
+ highest_number = max(gl.url_list_with_numbers.keys())
+ if isNext:
+ if gl.json_data['metadata']['nextPage'] is not None:
+ currentPage = int(gl.json_data['metadata']['currentPage'])
+ nextPage = currentPage + 2
+ prevPage = currentPage
+ pageCount = currentPage + 1
+ gl.json_data = request_civit_api(gl.json_data['metadata']['nextPage'])
+
+ gl.json_data["metadata"]["totalPages"] = highest_number
+ if not nextPage > highest_number:
+ gl.json_data["metadata"]["nextPage"] = gl.url_list_with_numbers.get(nextPage)
+ if not prevPage == 0:
+ gl.json_data["metadata"]["prevPage"] = gl.url_list_with_numbers.get(prevPage)
+ gl.json_data["metadata"]["currentPage"] = pageCount
+ else:
+ gl.json_data = None
+ else:
+ if gl.json_data['metadata']['prevPage'] is not None:
+ currentPage = int(gl.json_data['metadata']['currentPage'])
+ nextPage = currentPage
+ prevPage = currentPage - 2
+ pageCount = currentPage - 1
+ gl.json_data = request_civit_api(gl.json_data['metadata']['prevPage'])
+
+ gl.json_data["metadata"]["totalPages"] = highest_number
+ if not nextPage > highest_number:
+ gl.json_data["metadata"]["nextPage"] = gl.url_list_with_numbers.get(nextPage)
+ if not prevPage == 0:
+ gl.json_data["metadata"]["prevPage"] = gl.url_list_with_numbers.get(prevPage)
+ gl.json_data["metadata"]["currentPage"] = pageCount
+ else:
+ gl.json_data = None
+
+ # No page in the requested direction: returns None instead of the update tuple.
+ if gl.json_data is None:
+ return
+
+ # NOTE(review): total_pages does not appear to be assigned in this timeout branch, although
+ # page_string below reads it - confirm.
+ if gl.json_data == "timeout":
+ HTML = '
The Civit-API has timed out, please try again. The servers might be too busy or down if the issue persists.
'
+ hasPrev = current_page not in [0, 1]
+ hasNext = current_page == 1 or hasPrev
+ model_dict = {}
+
+ if gl.json_data != None and gl.json_data != "timeout":
+ (hasPrev, hasNext, current_page, total_pages) = pagecontrol(gl.json_data)
+ model_dict = {}
+ try:
+ gl.json_data['items']
+ except TypeError:
+ return gr.Dropdown.update(choices=[], value=None)
+
+ HTML = model_list_html(gl.json_data)
+
+ page_string = f"Page: {current_page}/{total_pages}"
+
+ return (
+ gr.Dropdown.update(choices=[v for k, v in model_dict.items()], value="", interactive=True), # Model List
+ gr.Dropdown.update(choices=[], value=""), # Version List
+ gr.HTML.update(value=HTML), # HTML Tiles
+ gr.Button.update(interactive=hasPrev), # Prev Page Button
+ gr.Button.update(interactive=hasNext), # Next Page Button
+ gr.Slider.update(value=current_page, maximum=total_pages, label=page_string), # Page Count
+ gr.Button.update(interactive=False), # Save Tags
+ gr.Button.update(interactive=False), # Save Images
+ gr.Button.update(interactive=False, visible=False if gl.isDownloading else True), # Download Button
+ gr.Button.update(interactive=False, visible=False), # Delete Button
+ gr.Textbox.update(interactive=False, value=None), # Install Path
+ gr.Dropdown.update(choices=[], value="", interactive=False), # Sub Folder List
+ gr.Dropdown.update(choices=[], value="", interactive=False), # File List
+ gr.HTML.update(value=''), # Preview HTML
+ gr.Textbox.update(value=None), # Trained Tags
+ gr.Textbox.update(value=None), # Base Model
+ gr.Textbox.update(value=None) # Model Filename
+ )
+
+def pagecontrol(json_data):
+    """Summarise the paging state stored in ``json_data['metadata']``.
+
+    Returns:
+        ``(hasPrev, hasNext, current_page, total_pages)``. The flags tell whether a
+        ``prevPage`` / ``nextPage`` key is present (whatever its value), and both page
+        numbers are returned as strings.
+    """
+    metadata = json_data['metadata']
+    current_page = f"{metadata['currentPage']}"
+    total_pages = f"{metadata['totalPages']}"
+    return 'prevPage' in metadata, 'nextPage' in metadata, current_page, total_pages
+
+def update_model_list(content_type=None, sort_type=None, period_type=None, use_search_term=None, search_term=None, current_page=None, base_filter=None, only_liked=None, nsfw=None, tile_count=None, timeOut=None, isNext=None, from_ver=False, from_installed=False):
+ """Request a model listing and return the Gradio updates that display it.
+
+ Stores the current filters in gl.previous_inputs and the response in gl.json_data. With
+ from_ver or from_installed set, gl.ver_json is displayed instead of the fresh response.
+ Returns None when the data is None, otherwise a tuple of component updates.
+ """
+ use_LORA = getattr(opts, "use_LORA", False)
+ model_list = []
+ id_list = []
+
+ # With the use_LORA option, the combined 'LORA & LoCon' entry is replaced by both types.
+ if content_type:
+ if use_LORA and 'LORA & LoCon' in content_type:
+ content_type.remove('LORA & LoCon')
+ if 'LORA' not in content_type:
+ content_type.append('LORA')
+ if 'LoCon' not in content_type:
+ content_type.append('LoCon')
+
+ if not from_ver and not from_installed:
+ gl.ver_json = None
+
+ # Detect a change of filters compared to the previous call (never on the very first call).
+ current_inputs = (content_type, sort_type, period_type, use_search_term, search_term, tile_count, base_filter, nsfw)
+ if current_inputs != gl.previous_inputs and gl.previous_inputs != None:
+ inputs_changed = True
+ else:
+ inputs_changed = False
+
+ gl.previous_inputs = current_inputs
+
+ gl.json_data = api_to_data(content_type, sort_type, period_type, use_search_term, current_page, base_filter, only_liked, tile_count, search_term, nsfw, timeOut, isNext, inputs_changed)
+ # On a timeout an error message is shown and the paging buttons are derived from current_page.
+ if gl.json_data == "timeout":
+ HTML = '
The Civit-API has timed out, please try again. The servers might be too busy or down if the issue persists.
'
+ hasPrev = current_page not in [0, 1]
+ hasNext = current_page == 1 or hasPrev
+
+ if gl.json_data is None:
+ return
+
+ # The fresh response is discarded in favour of the stored version/installed listing.
+ if from_installed or from_ver:
+ gl.json_data = gl.ver_json
+
+ if gl.json_data != None and gl.json_data != "timeout":
+ if not from_ver:
+ (hasPrev, hasNext, current_page, total_pages) = pagecontrol(gl.json_data)
+ else:
+ current_page = 1
+ total_pages = 1
+ hasPrev = False
+ hasNext = False
+ # Dropdown entries use the "<name> (<id>)" format.
+ for item in gl.json_data['items']:
+ model_list.append(f"{item['name']} ({item['id']})")
+
+ HTML = model_list_html(gl.json_data)
+ else:
+ current_page = 1
+ total_pages = 1
+
+ page_string = f"Page: {current_page}/{total_pages}"
+
+ return (
+ gr.Dropdown.update(choices=model_list, value="", interactive=True), # Model List
+ gr.Dropdown.update(choices=[], value=""), # Version List
+ gr.HTML.update(value=HTML), # HTML Tiles
+ gr.Button.update(interactive=hasPrev), # Prev Page Button
+ gr.Button.update(interactive=hasNext), # Next Page Button
+ gr.Slider.update(value=current_page, maximum=total_pages, label=page_string), # Page Count
+ gr.Button.update(interactive=False), # Save Tags
+ gr.Button.update(interactive=False), # Save Images
+ gr.Button.update(interactive=False, visible=False if gl.isDownloading else True), # Download Button
+ gr.Button.update(interactive=False, visible=False), # Delete Button
+ gr.Textbox.update(interactive=False, value=None, visible=True), # Install Path
+ gr.Dropdown.update(choices=[], value="", interactive=False), # Sub Folder List
+ gr.Dropdown.update(choices=[], value="", interactive=False), # File List
+ gr.HTML.update(value=''), # Preview HTML
+ gr.Textbox.update(value=None), # Trained Tags
+ gr.Textbox.update(value=None), # Base Model
+ gr.Textbox.update(value=None) # Model Filename
+ )
+
+def update_model_versions(model_id, json_input=None):
+    """Build the version dropdown for the model with id ``model_id``.
+
+    The model is looked up in ``json_input`` (or ``gl.json_data`` when not given). Its
+    storage folder is stored in ``gl.main_folder`` and scanned: a version counts as
+    installed when a ``.json`` file in the folder carries a ``sha256`` equal to one of
+    the version's file hashes, or when a file name equals one of the version's file names.
+    Installed versions are suffixed with " [Installed]".
+
+    Returns:
+        A ``gr.Dropdown`` update. The preselected value is the first installed version in
+        API order, else the first version; an empty, non-interactive dropdown is returned
+        when the model id is not found.
+    """
+    api_json = json_input if json_input else gl.json_data
+    for item in api_json['items']:
+        if int(item['id']) != int(model_id):
+            continue
+
+        content_type = item['type']
+        desc = item.get('description', "None")
+
+        model_folder = os.path.join(contenttype_folder(content_type, desc))
+        gl.main_folder = model_folder
+
+        # Lists (not sets) keep the API order, so matching and the default selection are
+        # the same on every run instead of depending on set iteration order.
+        version_names = []
+        version_files = []
+        for version in item['modelVersions']:
+            if version['name'] not in version_names:
+                version_names.append(version['name'])
+            for version_file in version['files']:
+                file_sha256 = version_file.get('hashes', {}).get('SHA256', "").upper()
+                version_files.append((version['name'], version_file['name'], file_sha256))
+
+        installed_versions = set()
+        for root, _dirs, files in os.walk(model_folder, followlinks=True):
+            for file in files:
+                if file.endswith('.json'):
+                    try:
+                        json_path = os.path.join(root, file)
+                        with open(json_path, 'r', encoding="utf-8") as f:
+                            json_data = json.load(f)
+                        sha256 = json_data.get('sha256') if isinstance(json_data, dict) else None
+                        if sha256:
+                            sha256 = sha256.upper()
+                            for version_name, _filename, file_sha256 in version_files:
+                                if sha256 == file_sha256:
+                                    installed_versions.add(version_name)
+                                    break
+                    except Exception as e:
+                        # Best effort: an unreadable sidecar file must not break the listing.
+                        print(f"failed to read: \"{file}\": {e}")
+
+                # Independently of any sidecar file, a matching file name also counts.
+                for version_name, version_filename, _sha in version_files:
+                    if file == version_filename:
+                        installed_versions.add(version_name)
+                        break
+
+        display_version_names = [f"{v} [Installed]" if v in installed_versions else v for v in version_names]
+        default_installed = next((f"{v} [Installed]" for v in version_names if v in installed_versions), None)
+        default_value = default_installed or next(iter(version_names), None)
+
+        return gr.Dropdown.update(choices=display_version_names, value=default_value, interactive=True) # Version List
+
+    return gr.Dropdown.update(choices=[], value=None, interactive=False) # Version List
+
+def cleaned_name(file_name):
+    r"""Remove characters that are illegal in file names from the stem of ``file_name``.
+
+    On Windows every character of ``\ / : * ? " < > |`` is removed, on all other
+    platforms only ``/``. The extension, as split off by ``os.path.splitext``, is
+    appended unchanged.
+    """
+    on_windows = platform.system() == "Windows"
+    forbidden_pattern = r'[\\/:*?"<>|]' if on_windows else r'/'
+
+    stem, suffix = os.path.splitext(file_name)
+    return re.sub(forbidden_pattern, '', stem) + suffix
+
+def fetch_and_process_image(image_url):
+    """Download an image and return the generation info read from it.
+
+    Returns the first element of read_info_from_image() for the fetched
+    image, or None when the server answers with a status other than 200.
+    """
+    # Bound the request so an unresponsive image host cannot block the caller
+    # indefinitely; uses the same (connect, read) limits as request_civit_api().
+    response = requests.get(image_url, timeout=(10, 30))
+    if response.status_code != 200:
+        return None
+    image = Image.open(BytesIO(response.content))
+    geninfo, _ = read_info_from_image(image)
+    return geninfo
+
+def extract_model_info(input_string):
+    """Split a "Name (id)" display string into a (name, id) tuple.
+
+    The id is the text between the last "(" and the last ")" converted with
+    int(); the name is everything before that last "(", stripped of
+    surrounding whitespace.
+    """
+    open_idx = input_string.rfind("(")
+    close_idx = input_string.rfind(")")
+    model_name = input_string[:open_idx].strip()
+    return model_name, int(input_string[open_idx + 1:close_idx])
+
+def update_model_info(model_string=None, model_version=None, only_html=False, input_id=None, json_input=None, from_preview=False):
+ # Collect the details of one model version and return them as a 13-element tuple of Gradio component updates.
+ # model_string: "Name (id)" display string, parsed with extract_model_info() when input_id is not given.
+ # model_version: version name; a " [Installed]" marker is stripped before matching.
+ # only_html: when true, output_html is returned on its own instead of the tuple.
+ # input_id: explicit model id that takes precedence over model_string.
+ # json_input: API payload to search instead of gl.json_data.
+ # NOTE(review): from_preview is not read anywhere in the lines visible here.
+ # NOTE(review): several HTML string literals below appear to have lost their markup in this copy of the patch.
+ video_playback = getattr(opts, "video_playback", True)
+ meta_btn = getattr(opts, "individual_meta_btn", True)
+ playback = ""
+ if video_playback: playback = "autoplay loop"
+
+ if json_input:
+ api_data = json_input
+ else:
+ api_data = gl.json_data
+
+ BtnDownInt = True
+ BtnDel = False
+ BtnImage = False
+ model_id = None
+
+ # Resolve the model id from the explicit input_id or, failing that, from the display string.
+ if not input_id:
+ _, model_id = extract_model_info(model_string)
+ else:
+ model_id = input_id
+
+ if model_version and "[Installed]" in model_version:
+ model_version = model_version.replace(" [Installed]", "")
+ if model_id:
+ output_html = ""
+ output_training = ""
+ output_basemodel = ""
+ img_html = ""
+ dl_dict = {}
+ is_LORA = False
+ file_list = []
+ file_dict = []
+ default_file = None
+ model_filename = None
+ sha256_value = None
+ for item in api_data['items']:
+ if int(item['id']) == int(model_id):
+ content_type = item['type']
+ if content_type == "LORA":
+ is_LORA = True
+ desc = item['description']
+ model_name = item['name']
+ model_folder = os.path.join(contenttype_folder(content_type, desc))
+ model_uploader = item['creator']['username']
+ uploader_avatar = item['creator']['image']
+ if uploader_avatar is None:
+ uploader_avatar = ''
+ else:
+ uploader_avatar = f''
+ tags = item.get('tags', "")
+ model_desc = item.get('description', "")
+ if model_desc:
+ model_desc = model_desc.replace('', '')
+ # NOTE(review): selected_version is only assigned when model_version is None or matches an entry; an unknown version name leaves it unbound — confirm callers always pass a valid name.
+ if model_version is None:
+ selected_version = item['modelVersions'][0]
+ else:
+ for model in item['modelVersions']:
+ if model['name'] == model_version:
+ selected_version = model
+ break
+
+ if selected_version['trainedWords']:
+ output_training = ",".join(selected_version['trainedWords'])
+ output_training = re.sub(r'<[^>]*:[^>]*>', '', output_training)
+ output_training = re.sub(r', ?', ', ', output_training)
+ output_training = output_training.strip(', ')
+ if selected_version['baseModel']:
+ output_basemodel = selected_version['baseModel']
+ for file in selected_version['files']:
+ dl_dict[file['name']] = file['downloadUrl']
+
+ if not model_filename:
+ model_filename = file['name']
+ dl_url = file['downloadUrl']
+ gl.json_info = item
+ sha256_value = file['hashes'].get('SHA256', 'Unknown')
+
+ size = file['metadata'].get('size', 'Unknown')
+ format = file['metadata'].get('format', 'Unknown')
+ fp = file['metadata'].get('fp', 'Unknown')
+ # Despite its name, this local holds the size in bytes (API sizeKB * 1024) for convert_size().
+ sizeKB = file.get('sizeKB', 0) * 1024
+ filesize = _download.convert_size(sizeKB)
+
+ unique_file_name = f"{size} {format} {fp} ({filesize})"
+ is_primary = file.get('primary', False)
+ file_list.append(unique_file_name)
+ file_dict.append({
+ "format": format,
+ "sizeKB": sizeKB
+ })
+ if is_primary:
+ default_file = unique_file_name
+ model_filename = file['name']
+ dl_url = file['downloadUrl']
+ gl.json_info = item
+ sha256_value = file['hashes'].get('SHA256', 'Unknown')
+
+ safe_tensor_found = False
+ pickle_tensor_found = False
+ if is_LORA and file_dict:
+ for file_info in file_dict:
+ file_format = file_info.get("format", "")
+ if "SafeTensor" in file_format:
+ safe_tensor_found = True
+ if "PickleTensor" in file_format:
+ pickle_tensor_found = True
+
+ # A LORA that ships both formats and whose first file is a small PickleTensor is routed to the TextualInversion folder.
+ # NOTE(review): file_dict["sizeKB"] stores bytes (see above), so "<= 100" compares bytes here while update_file_info() compares the raw API sizeKB — confirm the intended threshold.
+ if safe_tensor_found and pickle_tensor_found:
+ if "PickleTensor" in file_dict[0].get("format", ""):
+ if file_dict[0].get("sizeKB", 0) <= 100:
+ model_folder = os.path.join(contenttype_folder("TextualInversion"))
+
+ model_url = selected_version.get('downloadUrl', '')
+ model_main_url = f"https://civitai.com/models/{item['id']}"
+ img_html = '
'
+ for index, pic in enumerate(selected_version['images']):
+ meta_button = False
+ meta = pic['meta']
+ if meta and meta.get('prompt'):
+ meta_button = True
+ BtnImage = True
+ # Change width value in URL to original image width
+ image_url = re.sub(r'/width=\d+', f'/width={pic["width"]}', pic["url"])
+ if pic['type'] == "video":
+ image_url = image_url.replace("width=", "transcode=true,width=")
+ nsfw = 'class="model-block"'
+
+ if pic['nsfw'] not in ["None", "Soft"]:
+ nsfw = 'class="civnsfw model-block"'
+
+ img_html += f'''
+
'
+ # Define the preferred order of keys and convert them to lowercase
+ preferred_order = ["prompt", "negativePrompt", "seed", "Size", "Model", "clipSkip", "sampler", "steps", "cfgScale"]
+ preferred_order_lower = [key.lower() for key in preferred_order]
+ # Loop through the keys in the preferred order and add them to the HTML
+ for key in preferred_order:
+ if key in meta:
+ value = meta[key]
+ if meta_btn:
+ img_html += f'
{escape(str(key).capitalize())}
{escape(str(value))}
'
+ else:
+ img_html += f'
{escape(str(key).capitalize())}
{escape(str(value))}
'
+ # Check if there are remaining keys in meta
+ remaining_keys = [key for key in meta if key.lower() not in preferred_order_lower]
+
+ # Add the rest
+ if remaining_keys:
+ img_html += f"""
+
+
+
+
+
+ """
+ for key in remaining_keys:
+ value = meta[key]
+ img_html += f'
{escape(str(key).capitalize())}
{escape(str(value))}
'
+ img_html = img_html + '
'
+
+ img_html += '
'
+
+ img_html = img_html + '
'
+ img_html = img_html + ''
+ tags_html = ''.join([f'{escape(str(tag))}' for tag in tags])
+ def perms_svg(color):
+ return f''\
+ f''
+ deny_svg = f'{perms_svg("red")}'
+ perms_html= '
'\
+ f'{allow_svg if item.get("allowNoCredit") else deny_svg} Use the model without crediting the creator '\
+ f'{allow_svg if item.get("allowCommercialUse") in ["Image", "Rent", "RentCivit", "Sell"] else deny_svg} Sell images they generate '\
+ f'{allow_svg if item.get("allowCommercialUse") in ["Rent", "Sell"] else deny_svg} Run on services that generate images for money '\
+ f'{allow_svg if item.get("allowCommercialUse") in ["RentCivit", "Rent", "Sell"] else deny_svg} Run on Civitai '\
+ f'{allow_svg if item.get("allowDerivatives") else deny_svg} Share merges using this model '\
+ f'{allow_svg if item.get("allowCommercialUse") == "Sell" else deny_svg} Sell this model or merges using this model '\
+ f'{allow_svg if item.get("allowDifferentLicense") else deny_svg} Have different permissions when sharing merges'\
+ '
+ '''
+
+ if only_html:
+ return output_html
+
+ folder_location = "None"
+ default_subfolder = "None"
+ sub_folders = ["None"]
+
+ # Search model_folder for an installed copy, matching either the sha256 stored in a .json file or the file name.
+ for root, dirs, files in os.walk(model_folder, followlinks=True):
+ for filename in files:
+ if filename.endswith('.json'):
+ json_file_path = os.path.join(root, filename)
+ with open(json_file_path, 'r', encoding="utf-8") as f:
+ try:
+ data = json.load(f)
+ sha256 = data.get('sha256')
+ if sha256:
+ sha256 = sha256.upper()
+ if sha256 == sha256_value:
+ folder_location = root
+ BtnDownInt = False
+ BtnDel = True
+
+ break
+ except Exception as e:
+ print(f"Error decoding JSON: {str(e)}")
+ else:
+ for filename in files:
+ if filename == model_filename or filename == cleaned_name(model_filename):
+ folder_location = root
+ BtnDownInt = False
+ BtnDel = True
+ break
+
+ if folder_location != "None":
+ break
+
+ insert_sub_1 = getattr(opts, "insert_sub_1", False)
+ insert_sub_2 = getattr(opts, "insert_sub_2", False)
+ insert_sub_3 = getattr(opts, "insert_sub_3", False)
+ insert_sub_4 = getattr(opts, "insert_sub_4", False)
+ insert_sub_5 = getattr(opts, "insert_sub_5", False)
+ insert_sub_6 = getattr(opts, "insert_sub_6", False)
+ insert_sub_7 = getattr(opts, "insert_sub_7", False)
+ insert_sub_8 = getattr(opts, "insert_sub_8", False)
+ insert_sub_9 = getattr(opts, "insert_sub_9", False)
+ insert_sub_10 = getattr(opts, "insert_sub_10", False)
+ insert_sub_11 = getattr(opts, "insert_sub_11", False)
+ insert_sub_12 = getattr(opts, "insert_sub_12", False)
+ insert_sub_13 = getattr(opts, "insert_sub_13", False)
+ insert_sub_14 = getattr(opts, "insert_sub_14", False)
+ dot_subfolders = getattr(opts, "dot_subfolders", True)
+
+ # Build the sub folder choices: existing directories below model_folder plus the optional generated combinations.
+ try:
+ sub_folders = ["None"]
+ for root, dirs, _ in os.walk(model_folder, followlinks=True):
+ if dot_subfolders:
+ dirs = [d for d in dirs if not d.startswith('.')]
+ dirs = [d for d in dirs if not any(part.startswith('.') for part in os.path.join(root, d).split(os.sep))]
+ for d in dirs:
+ sub_folder = os.path.relpath(os.path.join(root, d), model_folder)
+ if sub_folder:
+ sub_folders.append(f'{os.sep}{sub_folder}')
+
+ sub_folders.remove("None")
+ sub_folders = sorted(sub_folders, key=lambda x: (x.lower(), x))
+ sub_folders.insert(0, "None")
+ # NOTE(review): "base" is derived from model_uploader exactly like "author"; output_basemodel may have been intended — confirm.
+ base = cleaned_name(model_uploader)
+ author = cleaned_name(model_uploader)
+ name = cleaned_name(model_name)
+ ver = cleaned_name(model_version)
+
+ if insert_sub_1:
+ sub_folders.insert(1, os.path.join(os.sep, base))
+ if insert_sub_2:
+ sub_folders.insert(2, os.path.join(os.sep, base, author))
+ if insert_sub_3:
+ sub_folders.insert(3, os.path.join(os.sep, base, author, name))
+ if insert_sub_4:
+ sub_folders.insert(4, os.path.join(os.sep, base, author, name, ver))
+ if insert_sub_5:
+ sub_folders.insert(5, os.path.join(os.sep, base, name))
+ if insert_sub_6:
+ sub_folders.insert(6, os.path.join(os.sep, base, name, ver))
+ if insert_sub_7:
+ sub_folders.insert(7, os.path.join(os.sep, author))
+ if insert_sub_8:
+ sub_folders.insert(8, os.path.join(os.sep, author, base))
+ if insert_sub_9:
+ sub_folders.insert(9, os.path.join(os.sep, author, base, name))
+ if insert_sub_10:
+ sub_folders.insert(10, os.path.join(os.sep, author, base, name, ver))
+ if insert_sub_11:
+ sub_folders.insert(11, os.path.join(os.sep, author, name))
+ if insert_sub_12:
+ sub_folders.insert(12, os.path.join(os.sep, author, name, ver))
+ if insert_sub_13:
+ sub_folders.insert(13, os.path.join(os.sep, name))
+ if insert_sub_14:
+ sub_folders.insert(14, os.path.join(os.sep, name, ver))
+
+ # Drop duplicates while keeping the first occurrence; note that "list" shadows the builtin inside this function.
+ list = set()
+ sub_folders = [x for x in sub_folders if not (x in list or list.add(x))]
+ # NOTE(review): this bare except also swallows a failure of cleaned_name(model_version) (e.g. model_version is None), in which case "ver" is never bound and variable_mapping below would raise — confirm.
+ except:
+ sub_folders = ["None"]
+
+ default_sub = sub_folder_value(content_type, desc)
+
+ # Placeholder names that may appear in the configured default sub folder and the values substituted for them.
+ variable_mapping = {
+ "Base model": base,
+ "Author name": author,
+ "Model name": name,
+ "Model version": ver
+ }
+
+ if any(key in default_sub for key in variable_mapping.keys()):
+ path_components = [variable_mapping.get(component.strip(os.sep), component.strip(os.sep)) for component in default_sub.split(os.sep)]
+ default_sub = os.path.join(*path_components)
+
+ if folder_location == "None":
+ folder_location = model_folder
+ if default_sub != "None":
+ folder_path = folder_location + default_sub
+ else:
+ folder_path = folder_location
+ else:
+ folder_path = folder_location
+ relative_path = os.path.relpath(folder_location, model_folder)
+ default_subfolder = f'{os.sep}{relative_path}' if relative_path != "." else default_sub if BtnDel == False else "None"
+ # While a download is running, clear BtnDel for the model at the head of the queue.
+ if gl.isDownloading:
+ item = gl.download_queue[0]
+ if int(model_id) == int(item['model_id']):
+ BtnDel = False
+ BtnDownTxt = "Download model"
+ if len(gl.download_queue) > 0:
+ BtnDownTxt = "Add to queue"
+ for item in gl.download_queue:
+ if item['version_name'] == model_version and int(item['model_id']) == int(model_id):
+ BtnDownInt = False
+ break
+
+ return (
+ gr.HTML.update(value=output_html), # Preview HTML
+ gr.Textbox.update(value=output_training, interactive=True), # Trained Tags
+ gr.Textbox.update(value=output_basemodel), # Base Model Number
+ gr.Button.update(visible=False if BtnDel else True, interactive=BtnDownInt, value=BtnDownTxt), # Download Button
+ gr.Button.update(interactive=BtnImage), # Images Button
+ gr.Button.update(visible=BtnDel, interactive=BtnDel), # Delete Button
+ gr.Dropdown.update(choices=file_list, value=default_file, interactive=True), # File List
+ gr.Textbox.update(value=cleaned_name(model_filename), interactive=True), # Model File Name
+ gr.Textbox.update(value=dl_url), # Download URL
+ gr.Textbox.update(value=model_id), # Model ID
+ gr.Textbox.update(value=sha256_value), # SHA256
+ gr.Textbox.update(interactive=True, value=folder_path if model_name else None), # Install Path
+ gr.Dropdown.update(choices=sub_folders, value=default_subfolder, interactive=True) # Sub Folder List
+ )
+ else:
+ # No model id: return the same 13 components in their empty / disabled state.
+ return (
+ gr.HTML.update(value=None), # Preview HTML
+ gr.Textbox.update(value=None, interactive=False), # Trained Tags
+ gr.Textbox.update(value=''), # Base Model Number
+ gr.Button.update(visible=False if BtnDel else True, value="Download model"), # Download Button
+ gr.Button.update(interactive=False), # Images Button
+ gr.Button.update(visible=BtnDel, interactive=BtnDel), # Delete Button
+ gr.Dropdown.update(choices=None, value=None, interactive=False), # File List
+ gr.Textbox.update(value=None, interactive=False), # Model File Name
+ gr.Textbox.update(value=None), # Download URL
+ gr.Textbox.update(value=None), # Model ID
+ gr.Textbox.update(value=None), # SHA256
+ gr.Textbox.update(interactive=False, value=None), # Install Path
+ gr.Dropdown.update(choices=None, value=None, interactive=False) # Sub Folder List
+ )
+
+def sub_folder_value(content_type, desc=None):
+    """Return the configured default sub folder for a content type.
+
+    Args:
+        content_type: model type string from the API (e.g. "LORA", "Upscaler").
+        desc: model description; only consulted for upscalers, where it is
+            searched for the upscaler family name. May be None.
+
+    Returns:
+        The value of the matching "<type>_subfolder" option, or the string
+        "None" when the option is unset or None.
+    """
+    use_LORA = getattr(opts, "use_LORA", False)
+    if content_type in ["LORA", "LoCon"] and use_LORA:
+        folder = getattr(opts, "LORA_LoCon_subfolder", "None")
+    elif content_type == "Upscaler":
+        # ESRGAN is the fallback; a more specific family named in the
+        # description takes precedence. The fallback used to be assigned after
+        # the loop, which always discarded the specific match.
+        folder = getattr(opts, "ESRGAN_subfolder", "None")
+        for upscale_type in ["SWINIR", "REALESRGAN", "GFPGAN", "BSRGAN"]:
+            # `desc` defaults to None, which does not support `in`.
+            if desc and upscale_type in desc:
+                folder = getattr(opts, f"{upscale_type}_subfolder", "None")
+    else:
+        folder = getattr(opts, f"{content_type}_subfolder", "None")
+    if folder is None:
+        return "None"
+    return folder
+
+def update_file_info(model_string, model_version, file_metadata):
+ # Resolve the file selected in the file list for one model version and return 8 Gradio component updates for it.
+ # model_string: "Name (id)" display string parsed with extract_model_info().
+ # model_version: version name; a " [Installed]" marker is stripped before matching.
+ # file_metadata: the "{size} {format} {fp} ({filesize})" label of the selected file.
+ # Falls through to the disabled/empty tuple at the end when nothing matches.
+ file_list = []
+ is_LORA = False
+ embed_check = False
+ model_name = None
+ model_id = None
+ model_name, model_id = extract_model_info(model_string)
+
+ if model_version and "[Installed]" in model_version:
+ model_version = model_version.replace(" [Installed]", "")
+ if model_id and model_version:
+ for item in gl.json_data['items']:
+ if int(item['id']) == int(model_id):
+ content_type = item['type']
+ if content_type == "LORA":
+ is_LORA = True
+ desc = item['description']
+ for model in item['modelVersions']:
+ if model['name'] == model_version:
+ # First pass: collect "{size} {format}" labels so the format mix of this version can be inspected.
+ for file in model['files']:
+ size = file['metadata'].get('size', 'Unknown')
+ format = file['metadata'].get('format', 'Unknown')
+ unique_file_name = f"{size} {format}"
+ file_list.append(unique_file_name)
+ pass
+
+ # embed_check is set for a LORA version that offers both SafeTensor and PickleTensor files.
+ if is_LORA and file_list:
+ extracted_formats = [file.split(' ')[1] for file in file_list]
+ if "SafeTensor" in extracted_formats and "PickleTensor" in extracted_formats:
+ embed_check = True
+
+ # Second pass: find the file whose label equals file_metadata.
+ for file in model['files']:
+ model_id = item['id']
+ file_name = file.get('name', 'Unknown')
+ sha256 = file['hashes'].get('SHA256', 'Unknown')
+ metadata = file.get('metadata', {})
+ file_size = metadata.get('size', 'Unknown')
+ file_format = metadata.get('format', 'Unknown')
+ file_fp = metadata.get('fp', 'Unknown')
+ sizeKB = file.get('sizeKB', 0)
+ sizeB = sizeKB * 1024
+ filesize = _download.convert_size(sizeB)
+
+ if f"{file_size} {file_format} {file_fp} ({filesize})" == file_metadata:
+ installed = False
+ folder_location = "None"
+ model_folder = os.path.join(contenttype_folder(content_type, desc))
+ # A PickleTensor of at most 100 KB next to a SafeTensor is stored in the TextualInversion folder.
+ if embed_check and file_format == "PickleTensor":
+ if sizeKB <= 100:
+ model_folder = os.path.join(contenttype_folder("TextualInversion"))
+ dl_url = file['downloadUrl']
+ gl.json_info = item
+ # Installed check, step 1: a file with the same name anywhere below model_folder.
+ for root, _, files in os.walk(model_folder, followlinks=True):
+ if file_name in files:
+ installed = True
+ folder_location = root
+ break
+
+ # Installed check, step 2: a .json file whose "sha256" entry matches this file's hash.
+ if not installed:
+ for root, _, files in os.walk(model_folder, followlinks=True):
+ for filename in files:
+ if filename.endswith('.json'):
+ with open(os.path.join(root, filename), 'r', encoding="utf-8") as f:
+ try:
+ data = json.load(f)
+ sha256_value = data.get('sha256')
+ if sha256_value != None and sha256_value.upper() == sha256:
+ folder_location = root
+ installed = True
+ break
+ except Exception as e:
+ print(f"Error decoding JSON: {str(e)}")
+ default_sub = sub_folder_value(content_type, desc)
+ if folder_location == "None":
+ folder_location = model_folder
+ if default_sub != "None":
+ folder_path = folder_location + default_sub
+ else:
+ folder_path = folder_location
+ else:
+ folder_path = folder_location
+ relative_path = os.path.relpath(folder_location, model_folder)
+ default_subfolder = f'{os.sep}{relative_path}' if relative_path != "." else default_sub if installed == False else "None"
+ BtnDownInt = not installed
+ BtnDownTxt = "Download model"
+ if len(gl.download_queue) > 0:
+ BtnDownTxt = "Add to queue"
+ # NOTE(review): this loop rebinds "item" (the outer loop variable) and matches on version_name only, whereas update_model_info() also compares model_id — confirm which is intended.
+ for item in gl.download_queue:
+ if item['version_name'] == model_version:
+ BtnDownInt = False
+ break
+
+ return (
+ gr.Textbox.update(value=cleaned_name(file['name']), interactive=True), # Model File Name Textbox
+ gr.Textbox.update(value=dl_url), # Download URL Textbox
+ gr.Textbox.update(value=model_id), # Model ID Textbox
+ gr.Textbox.update(value=sha256), # sha256 textbox
+ gr.Button.update(interactive=BtnDownInt, visible=False if installed else True, value=BtnDownTxt), # Download Button
+ gr.Button.update(interactive=True if installed else False, visible=True if installed else False), # Delete Button
+ gr.Textbox.update(interactive=True, value=folder_path if model_name else None), # Install Path
+ gr.Dropdown.update(value=default_subfolder, interactive=True) # Sub Folder List
+ )
+
+ # Nothing matched: return the same 8 components cleared / disabled.
+ return (
+ gr.Textbox.update(value=None, interactive=False), # Model File Name Textbox
+ gr.Textbox.update(value=None), # Download URL Textbox
+ gr.Textbox.update(value=None), # Model ID Textbox
+ gr.Textbox.update(value=None), # sha256 textbox
+ gr.Button.update(interactive=False, visible=True), # Download Button
+ gr.Button.update(interactive=False, visible=False), # Delete Button
+ gr.Textbox.update(interactive=False, value=None), # Install Path
+ gr.Dropdown.update(choices=None, value=None, interactive=False) # Sub Folder List
+ )
+
+def get_headers():
+    """Build the HTTP request headers used for CivitAI calls.
+
+    The User-Agent comes from UserAgent().chrome, with a fixed Chrome string
+    as fallback when that raises ImportError. An "Authorization: Bearer ..."
+    header is added when the "custom_api_key" option is set.
+    """
+    api_key = getattr(opts, "custom_api_key", "")
+    try:
+        user_agent = UserAgent().chrome
+    except ImportError:
+        user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
+
+    headers = {'User-Agent': user_agent}
+    # Static browser-like headers sent with every request.
+    headers.update({
+        'Sec-Ch-Ua': '"Brave";v="119", "Chromium";v="119", "Not?A_Brand";v="24"',
+        'Sec-Ch-Ua-Mobile': '?0',
+        'Sec-Ch-Ua-Platform': '"Windows"',
+        'Sec-Fetch-Dest': 'document',
+        'Sec-Fetch-Mode': 'navigate',
+        'Sec-Fetch-Site': 'none',
+        'Sec-Fetch-User': '?1',
+        'Sec-Gpc': '1',
+        'Upgrade-Insecure-Requests': '1',
+    })
+    if api_key:
+        headers['Authorization'] = f'Bearer {api_key}'
+    return headers
+
+def request_civit_api(api_url=None):
+    """GET `api_url` and return the decoded JSON body.
+
+    Returns the string "timeout" when the request fails with any
+    requests RequestException (including HTTP error statuses raised by
+    raise_for_status()) or when the body is not valid JSON.
+    """
+    headers = get_headers()
+    try:
+        response = requests.get(api_url, headers=headers, timeout=(10, 30))
+        response.raise_for_status()
+    except requests.exceptions.RequestException as e:
+        print(f"Error: {e}")
+        return "timeout"
+
+    response.encoding = "utf-8"
+    try:
+        return json.loads(response.text)
+    except json.JSONDecodeError:
+        print("The CivitAI servers are currently offline. Please try again later.")
+        return "timeout"
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/scripts/civitai_download.py b/sd-civitai-browser-plus/scripts/civitai_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..71bfceab62b6ea493f70d3c1dcb91e211b36cde3
--- /dev/null
+++ b/sd-civitai-browser-plus/scripts/civitai_download.py
@@ -0,0 +1,768 @@
+import requests
+import gradio as gr
+import time
+import subprocess
+import threading
+import os
+import re
+import random
+import platform
+import stat
+import json
+import time
+from pathlib import Path
+from modules.shared import opts, cmd_opts
+from scripts.civitai_global import print
+import scripts.civitai_global as gl
+import scripts.civitai_api as _api
+import scripts.civitai_file_manage as _file
+try:
+ from zip_unicode import ZipHandler
+except ImportError:
+ print("Python module 'ZipUnicode' has not been imported correctly, please try to restart or install it manually.")
+
+# Queue progress counters, shown as "Queue: {current_count}/{total_count}" in the download progress text.
+total_count = 0
+current_count = 0
+# Incremented for every created queue item and stored as that item's "dl_id".
+dl_manager_count = 0
+
+def random_number(prev=None):
+    """Return a random number in 10000..99999 as a string, never equal to `prev`."""
+    while True:
+        candidate = str(random.randint(10000, 99999))
+        if candidate != prev:
+            return candidate
+
+gl.init()
+# Shared secret sent as "token:<secret>" with every aria2 JSON-RPC call made by this module.
+# NOTE(review): the value is hard-coded and start_aria2_rpc() passes --rpc-listen-all, so any host that can reach port 24000 and knows this string can control aria2 — consider restricting the listener to localhost.
+rpc_secret = "R7T5P2Q9K6"
+# `queue` decides whether download_file() gets a gr.Progress() default; it mirrors the WebUI queue flag.
+# The AttributeError fallback presumably covers WebUI builds that expose disable_queue instead of no_gradio_queue — TODO confirm.
+try:
+ queue = not cmd_opts.no_gradio_queue
+except AttributeError:
+ queue = not cmd_opts.disable_queue
+except:
+ queue = True
+
+def start_aria2_rpc():
+    """Start (or restart) the bundled aria2 daemon with its JSON-RPC interface.
+
+    Two marker files inside `aria2path` coordinate repeated calls:
+      * "running" exists: the current aria2 process is killed and a new one
+        is started.
+      * "_" exists (and "running" does not): "_" is renamed to "running" and
+        the function returns without starting a process.
+      * neither exists: "_" is created and aria2 is started.
+
+    Failures to stop or start the process are printed, not raised.
+    """
+    start_file = os.path.join(aria2path, '_')
+    running_file = os.path.join(aria2path, 'running')
+
+    if os.path.exists(running_file):
+        try:
+            if os_type == 'Linux':
+                env = os.environ.copy()
+                env['PATH'] = '/usr/bin:' + env['PATH']
+                subprocess.Popen("pkill aria2", shell=True, env=env)
+            else:
+                # subprocess.DEVNULL replaces a handle on os.devnull that was
+                # opened on every call and never closed.
+                subprocess.Popen(stop_rpc, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+            # Give the old process a moment to exit before starting a new one.
+            time.sleep(1)
+        except Exception as e:
+            print(f"Failed to stop Aria2 RPC : {e}")
+    elif os.path.exists(start_file):
+        os.rename(start_file, running_file)
+        return
+    else:
+        with open(start_file, 'w', encoding="utf-8"):
+            pass
+
+    try:
+        show_log = getattr(opts, "show_log", False)
+        aria2_flags = getattr(opts, "aria2_flags", "")
+        cmd = f'"{aria2}" --enable-rpc --rpc-listen-all --rpc-listen-port=24000 --rpc-secret {rpc_secret} --check-certificate=false --ca-certificate=" " --file-allocation=none {aria2_flags}'
+        subprocess_args = {'shell': True}
+        if not show_log:
+            subprocess_args.update({'stdout': subprocess.DEVNULL, 'stderr': subprocess.DEVNULL})
+
+        subprocess.Popen(cmd, **subprocess_args)
+        if os.path.exists(running_file):
+            print("Aria2 RPC restarted")
+        else:
+            print("Aria2 RPC started")
+    except Exception as e:
+        print(f"Failed to start Aria2 RPC server: {e}")
+
+# Directory "aria2" next to the scripts folder; holds the per-platform binaries and the marker files used by start_aria2_rpc().
+aria2path = Path(__file__).resolve().parents[1] / "aria2"
+os_type = platform.system()
+
+# Select the platform binary and start the RPC daemon as a side effect of importing this module.
+# NOTE(review): for any other platform.system() value neither branch runs, so `aria2` and `stop_rpc` stay undefined and no daemon is started.
+if os_type == 'Windows':
+ aria2 = os.path.join(aria2path, 'win', 'aria2.exe')
+ stop_rpc = "taskkill /im aria2.exe /f"
+ start_aria2_rpc()
+elif os_type == 'Linux':
+ aria2 = os.path.join(aria2path, 'lin', 'aria2')
+ # Add the owner-execute bit to the binary's current mode.
+ st = os.stat(aria2)
+ os.chmod(aria2, st.st_mode | stat.S_IEXEC)
+ stop_rpc = "pkill aria2"
+ start_aria2_rpc()
+
+class TimeOutFunction(Exception):
+ # Custom exception type with no extra behavior; no raise site is visible in this part of the file.
+ pass
+
+def create_model_item(dl_url, model_filename, install_path, model_name, version_name, model_sha256, model_id, create_json, from_batch=False):
+    """Assemble the queue entry (a dict) that describes one download.
+
+    Returns None when `dl_url` is already present in gl.download_queue;
+    otherwise increments the module-level dl_manager_count and returns the
+    new entry, whose "dl_id" is that counter value.
+    """
+    global dl_manager_count
+    if model_id:
+        model_id = int(model_id)
+    if model_sha256:
+        model_sha256 = model_sha256.upper()
+
+    # Take the first API item with a matching id; it determines the content folder.
+    filtered_items = []
+    for entry in gl.json_data['items']:
+        if int(entry['id']) != int(model_id):
+            continue
+        filtered_items.append(entry)
+        content_type = entry['type']
+        desc = entry['description']
+        main_folder = _api.contenttype_folder(content_type, desc)
+        break
+
+    # Install path expressed relative to the content type's main folder.
+    sub_folder = os.path.normpath(os.path.relpath(install_path, main_folder))
+
+    model_json = {"items": filtered_items}
+    model_versions = _api.update_model_versions(model_id)
+    info_updates = _api.update_model_info(None, model_versions.get('value'), False, model_id)
+    (preview_html, _, _, _, _, _, _, _, _, _, _, existing_path, _) = info_updates
+
+    # A URL that is already queued is not queued twice.
+    if any(queued['dl_url'] == dl_url for queued in gl.download_queue):
+        return None
+
+    dl_manager_count += 1
+
+    return {
+        "dl_id": dl_manager_count,
+        "dl_url" : dl_url,
+        "model_filename" : model_filename,
+        "install_path" : install_path,
+        "model_name" : model_name,
+        "version_name" : version_name,
+        "model_sha256" : model_sha256,
+        "model_id" : model_id,
+        "create_json" : create_json,
+        "model_json" : model_json,
+        "model_versions" : model_versions,
+        "preview_html" : preview_html['value'],
+        "existing_path": existing_path['value'],
+        "from_batch" : from_batch,
+        "sub_folder" : sub_folder
+    }
+
+def selected_to_queue(model_list, subfolder, download_start, create_json, current_html):
+ # Queue the first version of every model in model_list (a JSON array of "Name (id)" strings) and return 6 Gradio updates.
+ # subfolder: sub folder chosen for the whole batch; "None" or the placeholder text falls back to each type's configured default.
+ # download_start: current trigger value; replaced by a new random number only when the queue was empty on entry.
+ global total_count, current_count
+ if gl.download_queue:
+ number = download_start
+ else:
+ number = random_number(download_start)
+ total_count = 0
+ current_count = 0
+
+ model_list = json.loads(model_list)
+
+ for model_string in model_list:
+ model_name, model_id = _api.extract_model_info(model_string)
+ # NOTE(review): if no item matches model_id, the names assigned in this loop keep their values from the previous model (or are unbound on the first one) — confirm that cannot happen.
+ for item in gl.json_data['items']:
+ if int(item['id']) == int(model_id):
+ model_id, desc, content_type = item['id'], item['description'], item['type']
+ version = item.get('modelVersions', [])[0]
+ version_name = version.get('name')
+ files = version.get('files', [])
+ # Prefer the file flagged as primary; otherwise use the first file of the version.
+ primary_file = next((file for file in files if file.get('primary', False)), None)
+ if primary_file:
+ model_filename = _api.cleaned_name(primary_file.get('name'))
+ model_sha256 = primary_file.get('hashes', {}).get('SHA256')
+ dl_url = primary_file.get('downloadUrl')
+ else:
+ model_filename = _api.cleaned_name(files[0].get('name'))
+ model_sha256 = files[0].get('hashes', {}).get('SHA256')
+ dl_url = files[0].get('downloadUrl')
+ break
+
+ model_folder = _api.contenttype_folder(content_type, desc)
+
+ sub_opt1 = os.path.join(os.sep, _api.cleaned_name(model_name))
+ sub_opt2 = os.path.join(os.sep, _api.cleaned_name(model_name), _api.cleaned_name(version_name))
+
+ # Expand the "Model Name" / "Version Name" placeholders of the configured default sub folder.
+ # NOTE(review): update_model_info() maps the placeholders "Model name" / "Model version" instead — confirm which spelling the settings actually store.
+ default_sub = _api.sub_folder_value(content_type, desc)
+ if default_sub == f"{os.sep}Model Name":
+ default_sub = sub_opt1
+ elif default_sub == f"{os.sep}Model Name{os.sep}Version Name":
+ default_sub = sub_opt2
+
+ if subfolder and subfolder != "None" and subfolder != "Only available if the selected files are of the same model type":
+ from_batch = False
+ if platform.system() == "Windows":
+ subfolder = re.sub(r'[/:*?"<>|]', '', subfolder)
+
+ if not subfolder.startswith(os.sep):
+ subfolder = os.sep + subfolder
+ install_path = model_folder + subfolder
+ else:
+ from_batch = True
+ if default_sub != "None":
+ install_path = model_folder + default_sub
+ else:
+ install_path = model_folder
+
+ # create_model_item() returns None for an already queued URL; only real items are appended and counted.
+ model_item = create_model_item(dl_url, model_filename, install_path, model_name, version_name, model_sha256, model_id, create_json, from_batch)
+ if model_item:
+ gl.download_queue.append(model_item)
+ total_count += 1
+
+ html = download_manager_html(current_html)
+
+ return (
+ gr.Button.update(interactive=False, visible=False), # Download Button
+ gr.Button.update(interactive=True, visible=True), # Cancel Button
+ gr.Button.update(interactive=True if len(gl.download_queue) > 1 else False, visible=True), # Cancel All Button
+ gr.Textbox.update(value=number), # Download Start Trigger
+ gr.HTML.update(value=''), # Download Progress
+ gr.HTML.update(value=html) # Download Manager HTML
+ )
+
+def download_start(download_start, dl_url, model_filename, install_path, model_string, version_name, model_sha256, model_id, create_json, current_html):
+    """Queue a single model download and return the Gradio updates for the download controls.
+
+    Args:
+        download_start: current value of the download trigger textbox. A new
+            random value is returned for it only when the added item is the
+            sole queue entry; otherwise it is passed through unchanged.
+        model_string: "Name (id)" display string of the model.
+        current_html: current download manager HTML, forwarded to
+            download_manager_html().
+        (remaining arguments are forwarded to create_model_item().)
+
+    Returns:
+        A 6-tuple of updates: download button, cancel button, cancel-all
+        button, trigger textbox, progress HTML, download manager HTML.
+    """
+    global total_count, current_count
+
+    model_item = None
+    if model_string:
+        model_name, _ = _api.extract_model_info(model_string)
+        model_item = create_model_item(dl_url, model_filename, install_path, model_name, version_name, model_sha256, model_id, create_json)
+
+    # create_model_item() returns None for a URL that is already queued.
+    # Appending that None used to corrupt gl.download_queue (later
+    # item['dl_url'] lookups fail), so only real items are queued and counted.
+    number = download_start
+    if model_item:
+        gl.download_queue.append(model_item)
+        if len(gl.download_queue) > 1:
+            total_count += 1
+        else:
+            number = random_number(download_start)
+            total_count = 1
+            current_count = 0
+
+    html = download_manager_html(current_html)
+    # With at least one queued item this yields exactly the previous button
+    # states; with an empty queue the download button stays usable.
+    busy = bool(gl.download_queue)
+
+    return (
+        gr.Button.update(interactive=not busy, visible=True), # Download Button
+        gr.Button.update(interactive=busy, visible=busy), # Cancel Button
+        gr.Button.update(interactive=True if len(gl.download_queue) > 1 else False, visible=busy), # Cancel All Button
+        gr.Textbox.update(value=number), # Download Start Trigger
+        gr.HTML.update(value=''), # Download Progress
+        gr.HTML.update(value=html) # Download Manager HTML
+    )
+
+def download_finish(model_filename, version, model_id):
+    """Reset the download controls after a download ended and refresh the version dropdown.
+
+    The delete button is shown only when "<gl.last_version> [Installed]" is
+    among the refreshed version choices and the download was not cancelled;
+    otherwise the download button is shown. gl.download_fail and
+    gl.cancel_status are cleared before returning.
+    """
+    version_choices = []
+    if model_id:
+        model_id = int(model_id)
+        model_versions = _api.update_model_versions(model_id)
+        if model_versions:
+            version_choices = model_versions.get('choices', [])
+
+    installed_label = gl.last_version + " [Installed]"
+    is_installed = installed_label in version_choices
+    if is_installed:
+        version = installed_label
+
+    # A cancelled download never counts as installed.
+    Del = True if (is_installed and not gl.cancel_status) else False
+    Down = not Del
+
+    gl.download_fail = False
+    gl.cancel_status = False
+
+    return (
+        gr.Button.update(interactive=model_filename, visible=Down, value="Download model"), # Download Button
+        gr.Button.update(interactive=False, visible=False), # Cancel Button
+        gr.Button.update(interactive=False, visible=False), # Cancel All Button
+        gr.Button.update(interactive=Del, visible=Del), # Delete Button
+        gr.HTML.update(value=''), # Download Progress
+        gr.Dropdown.update(value=version, choices=version_choices) # Version Dropdown
+    )
+
+def download_cancel():
+    """Cancel the download at the head of the queue.
+
+    Sets gl.cancel_status and gl.download_fail, waits (polling every 0.5 s)
+    until gl.isDownloading is cleared, then passes the head item to
+    _file.delete_model().
+    """
+    gl.cancel_status = True
+    gl.download_fail = True
+    # Default to None so an empty queue is handled; `item` used to be unbound
+    # in that case and the `if item:` check raised UnboundLocalError.
+    item = gl.download_queue[0] if gl.download_queue else None
+
+    while gl.isDownloading:
+        time.sleep(0.5)
+
+    if item:
+        model_string = f"{item['model_name']} ({item['model_id']})"
+        _file.delete_model(0, item['model_filename'], model_string, item['version_name'], False, model_ver=item['model_versions'], model_json=item['model_json'])
+    return
+
+def download_cancel_all():
+    """Cancel the running download and drop every queued item.
+
+    Sets gl.cancel_status and gl.download_fail, waits (polling every 0.5 s)
+    until gl.isDownloading is cleared, passes the head item to
+    _file.delete_model() and finally empties gl.download_queue.
+    """
+    gl.cancel_status = True
+    gl.download_fail = True
+
+    # Default to None so an empty queue is handled; `item` used to be unbound
+    # in that case and the `if item:` check raised UnboundLocalError.
+    item = gl.download_queue[0] if gl.download_queue else None
+
+    while gl.isDownloading:
+        time.sleep(0.5)
+
+    if item:
+        model_string = f"{item['model_name']} ({item['model_id']})"
+        _file.delete_model(0, item['model_filename'], model_string, item['version_name'], False, model_ver=item['model_versions'], model_json=item['model_json'])
+    gl.download_queue = []
+    return
+
+def convert_size(size):
+    """Format a byte count as a human-readable string with two decimals.
+
+    Units step by 1024 from bytes up to GB. GB is the largest unit, so
+    values of 1024 GB or more are still expressed in GB (e.g. "2048.00 GB").
+    """
+    for unit in ['bytes', 'KB', 'MB']:
+        if size < 1024:
+            return f"{size:.2f} {unit}"
+        size /= 1024
+    # Whatever is left is expressed in GB. The loop used to include GB and
+    # divide once more before this line, so 1 TiB was reported as "1.00 GB".
+    return f"{size:.2f} GB"
+
+def get_download_link(url):
+    """Resolve a CivitAI download URL to the location it redirects to.
+
+    Returns:
+        The "Location" header of a 3xx answer; the string "NO_API" when the
+        redirect body points at the login page with reason=download-auth;
+        None when the answer is not a redirect or carries no Location header.
+    """
+    headers = _api.get_headers()
+
+    # Bounded (connect, read) wait so a stalled server cannot hang the caller.
+    response = requests.get(url, headers=headers, allow_redirects=False, timeout=(10, 30))
+
+    if 300 <= response.status_code <= 308:
+        if "login?returnUrl" in response.text and "reason=download-auth" in response.text:
+            return "NO_API"
+
+        # .get(): a 3xx without a Location header (e.g. 304) is reported as
+        # "no link" instead of raising KeyError.
+        return response.headers.get("Location")
+    return None
+
+ # Download `url` through an aria2 JSON-RPC endpoint and poll the transfer until it
+ # completes, errors out or is cancelled. aria2 is told to store the file in
+ # `install_path` under the basename of `file_path`. The outcome is signalled via
+ # gl.download_fail (False, True or "NO_API") and, when `progress` is not None, via
+ # progress(...) calls. `rpc_secret`, `convert_size`, `current_count` and
+ # `total_count` are not defined inside this function.
+ def download_file(url, file_path, install_path, progress=gr.Progress() if queue else None):
+ try:
+ disable_dns = getattr(opts, "disable_dns", False)
+ split_aria2 = getattr(opts, "split_aria2", 64)
+ # Number of failed status polls tolerated before giving up
+ max_retries = 5
+ gl.download_fail = False
+ aria2_rpc_url = "http://localhost:24000/jsonrpc"
+
+ file_name = os.path.basename(file_path)
+
+ # Resolve the redirect target first; a falsy result means nothing can be downloaded
+ download_link = get_download_link(url)
+ if not download_link:
+ print(f'File: "{file_name}" not found on CivitAI servers, it looks like the file is not available for download.')
+ gl.download_fail = True
+ return
+
+ elif download_link == "NO_API":
+ print(f'File: "{file_name}" requires a personal CivitAI API to be downloaded, you can set your own API key in the CivitAI Browser+ settings in the SD-WebUI settings tab')
+ gl.download_fail = "NO_API"
+ if progress != None:
+ progress(0, desc=f'File: "{file_name}" requires a personal CivitAI API to be downloaded, you can set your own API key in the CivitAI Browser+ settings in the SD-WebUI settings tab')
+ time.sleep(5)
+ return
+
+ # Remove any file already present at the destination before starting
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ # The "async-dns" option is sent as the string "false" when disable_dns is set
+ if disable_dns:
+ dns = "false"
+ else:
+ dns = "true"
+
+ # split_aria2 is used for both "max-connection-per-server" and "split"
+ options = {
+ "dir": install_path,
+ "max-connection-per-server": str(f"{split_aria2}"),
+ "split": str(f"{split_aria2}"),
+ "async-dns": dns,
+ "out": file_name
+ }
+
+ payload = json.dumps({
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": "aria2.addUri",
+ "params": ["token:" + rpc_secret, [download_link], options]
+ })
+
+ # Submit the download; the "result" field of the reply is kept as the transfer id (gid)
+ try:
+ response = requests.post(aria2_rpc_url, data=payload)
+ data = json.loads(response.text)
+ if 'result' not in data:
+ raise ValueError(f'Failed to start download: {data}')
+ gid = data['result']
+ except Exception as e:
+ print(f"Failed to start download: {e}")
+ gl.download_fail = True
+ return
+
+ # Poll loop: one aria2.tellStatus request per iteration
+ while True:
+ # On cancel, ask aria2 to remove the transfer and stop polling
+ if gl.cancel_status:
+ payload = json.dumps({
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": "aria2.remove",
+ "params": ["token:" + rpc_secret, gid]
+ })
+ requests.post(aria2_rpc_url, data=payload)
+ if progress != None:
+ progress(0, desc=f"Download cancelled.")
+ return
+
+ try:
+ payload = json.dumps({
+ "jsonrpc": "2.0",
+ "id": "1",
+ "method": "aria2.tellStatus",
+ "params": ["token:" + rpc_secret, gid]
+ })
+
+ response = requests.post(aria2_rpc_url, data=payload)
+ status_info = json.loads(response.text)['result']
+
+ total_length = int(status_info['totalLength'])
+ completed_length = int(status_info['completedLength'])
+ download_speed = int(status_info['downloadSpeed'])
+
+ # Guard against division by zero while the total length is still reported as 0
+ progress_percent = (completed_length / total_length) * 100 if total_length else 0
+
+ remaining_size = total_length - completed_length
+ if download_speed > 0:
+ eta_seconds = remaining_size / download_speed
+ eta_formatted = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))
+ else:
+ eta_formatted = "XX:XX:XX"
+ if progress != None:
+ progress(progress_percent / 100, desc=f"Downloading: {file_name} - {convert_size(completed_length)}/{convert_size(total_length)} - Speed: {convert_size(download_speed)}/s - ETA: {eta_formatted} - Queue: {current_count}/{total_count}")
+
+ if status_info['status'] == 'complete':
+ print(f"Model saved to: {file_path}")
+ if progress != None:
+ progress(1, desc=f"Model saved to: {file_path}")
+ gl.download_fail = False
+ return
+
+ if status_info['status'] == 'error':
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: \"{file_name}\" Please try again.")
+ gl.download_fail = True
+ return
+
+ time.sleep(0.25)
+
+ # A failed status poll uses up one retry; when the counter reaches 0 the
+ # download is marked as failed, otherwise polling resumes after 5 seconds
+ except Exception as e:
+ print(f"Error occurred during Aria2 status update: {e}")
+ max_retries -= 1
+ if max_retries == 0:
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: \"{file_name}\" Please try again.")
+ gl.download_fail = True
+ return
+ time.sleep(5)
+ # Catch-all for anything not handled above: mark the download as failed.
+ # NOTE(review): indentation is not preserved in this copy; the file removal and
+ # sleep below presumably belong to this handler - confirm against the original.
+ except:
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: \"{file_name}\" Please try again.")
+ gl.download_fail = True
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ time.sleep(5)
+
+def info_to_json(install_path, model_id, model_sha256, unpackList=None):
+ json_file = os.path.splitext(install_path)[0] + ".json"
+ if os.path.exists(json_file):
+ try:
+ with open(json_file, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+ except Exception as e:
+ print(f"Failed to open {json_file}: {e}")
+ else:
+ data = {}
+
+ data['modelId'] = model_id
+ data['sha256'] = model_sha256
+ if unpackList:
+ data['unpackList'] = unpackList
+
+ with open(json_file, 'w', encoding="utf-8") as f:
+ json.dump(data, f, indent=4)
+
+ # Download `url` to `file_path` with `requests`, streaming the body in 1024-byte
+ # chunks and appending to the file. The outcome is signalled via gl.download_fail
+ # (False, True or "NO_API") and, when `progress` is not None, via progress(...)
+ # calls. `TimeOutFunction`, `convert_size`, `current_count` and `total_count` are
+ # not defined inside this function.
+ def download_file_old(url, file_path, progress=gr.Progress() if queue else None):
+ try:
+ gl.download_fail = False
+ # Number of TimeOutFunction failures tolerated before giving up
+ max_retries = 5
+ if os.path.exists(file_path):
+ os.remove(file_path)
+
+ downloaded_size = 0
+ # Last path component, used only in messages
+ tokens = re.split(re.escape(os.sep), file_path)
+ file_name_display = tokens[-1]
+ start_time = time.time()
+ last_update_time = 0
+ # Minimum number of seconds between two progress(...) updates
+ update_interval = 0.25
+
+ download_link = get_download_link(url)
+ if not download_link:
+ print(f'File: "{file_name_display}" not found on CivitAI servers, it looks like the file is not available for download.')
+ if progress != None:
+ progress(0, desc=f'File: "{file_name_display}" not found on CivitAI servers, it looks like the file is not available for download.')
+ time.sleep(5)
+ gl.download_fail = True
+ return
+
+ elif download_link == "NO_API":
+ print(f'File: "{file_name_display}" requires a personal CivitAI API key to be downloaded, you can set your own API key in the CivitAI Browser+ settings in the SD-WebUI settings tab')
+ gl.download_fail = "NO_API"
+ if progress != None:
+ progress(0, desc=f'File: "{file_name_display}" requires a personal CivitAI API key to be downloaded, you can set your own API key in the CivitAI Browser+ settings in the SD-WebUI settings tab')
+ time.sleep(5)
+ return
+
+ while True:
+ if gl.cancel_status:
+ if progress != None:
+ progress(0, desc=f"Download cancelled.")
+ return
+ # When a partial file exists, request only the remaining bytes via a Range header
+ if os.path.exists(file_path):
+ downloaded_size = os.path.getsize(file_path)
+ headers = {"Range": f"bytes={downloaded_size}-"}
+ else:
+ headers = {}
+ with open(file_path, "ab") as f:
+ while gl.isDownloading:
+ try:
+ if gl.cancel_status:
+ if progress != None:
+ progress(0, desc=f"Download cancelled.")
+ return
+ try:
+ if gl.cancel_status:
+ if progress != None:
+ progress(0, desc=f"Download cancelled.")
+ return
+ response = requests.get(download_link, headers=headers, stream=True, timeout=4)
+ if response.status_code == 404:
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: {file_name_display}, file is not found on CivitAI servers.")
+ gl.download_fail = True
+ return
+ total_size = int(response.headers.get("Content-Length", 0))
+ # Any exception while opening the request is re-raised as TimeOutFunction,
+ # which is the exception type the retry handler further down catches
+ except:
+ raise TimeOutFunction("Timed Out")
+
+ # Without a Content-Length the size already on disk is used as the total
+ if total_size == 0:
+ total_size = downloaded_size
+
+ for chunk in response.iter_content(chunk_size=1024):
+ if chunk:
+ if gl.cancel_status:
+ if progress != None:
+ progress(0, desc=f"Download cancelled.")
+ return
+ f.write(chunk)
+ downloaded_size += len(chunk)
+ # Speed is the average since start_time, not an instantaneous rate
+ elapsed_time = time.time() - start_time
+ download_speed = downloaded_size / elapsed_time
+ remaining_size = total_size - downloaded_size
+ if download_speed > 0:
+ eta_seconds = remaining_size / download_speed
+ eta_formatted = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))
+ else:
+ eta_formatted = "XX:XX:XX"
+ current_time = time.time()
+ if current_time - last_update_time >= update_interval:
+ if progress != None:
+ progress(downloaded_size / total_size, desc=f"Downloading: {file_name_display} {convert_size(downloaded_size)} / {convert_size(total_size)} - Speed: {convert_size(int(download_speed))}/s - ETA: {eta_formatted} - Queue: {current_count}/{total_count}")
+ last_update_time = current_time
+ if gl.isDownloading == False:
+ # NOTE(review): `response.close` is only referenced, not called, so this line
+ # does not close the response - confirm whether `response.close()` was intended.
+ response.close
+ break
+ downloaded_size = os.path.getsize(file_path)
+ break
+
+ # Retry handling: each TimeOutFunction uses up one retry, then waits 5 seconds
+ except TimeOutFunction:
+ if progress != None:
+ progress(0, desc="CivitAI API did not respond, retrying...")
+ max_retries -= 1
+ if max_retries == 0:
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: {file_name_display}, please try again.")
+ gl.download_fail = True
+ return
+ time.sleep(5)
+
+ if (gl.isDownloading == False):
+ break
+
+ gl.isDownloading = False
+ # The download counts as successful when the size on disk reaches total_size.
+ # NOTE(review): indentation is not preserved in this copy; the `else` below
+ # presumably pairs with this size check - confirm against the original.
+ downloaded_size = os.path.getsize(file_path)
+ if downloaded_size >= total_size:
+ if not gl.cancel_status:
+ print(f"Model saved to: {file_path}")
+ if progress != None:
+ progress(1, desc=f"Model saved to: {file_path}")
+ gl.download_fail = False
+ return
+
+ else:
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: {file_name_display}, please try again.")
+ print(f"File download failed: {file_name_display}")
+ gl.download_fail = True
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ # Catch-all for anything not handled above: mark the download as failed and
+ # remove the file if it exists
+ except:
+ if progress != None:
+ progress(0, desc=f"Encountered an error during download of: {file_name_display}, please try again.")
+ gl.download_fail = True
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ time.sleep(5)
+
+ # Process the first entry of gl.download_queue: run the matching downloader in a
+ # thread and wait for it, post-process the downloaded file (optional zip
+ # extraction, info/preview/image saving), clean up after a failure, pop the entry
+ # from the queue and return a 5-tuple of gradio updates (progress HTML, current
+ # model, download-finish trigger, queue trigger, cancel-all button).
+ def download_create_thread(download_finish, queue_trigger, progress=gr.Progress() if queue else None):
+ global current_count
+ current_count += 1
+ # Nothing queued: only refresh the finish trigger and disable the cancel-all button
+ if not gl.download_queue:
+ return (
+ gr.HTML.update(), # Download Progress HTML
+ gr.Textbox.update(value=None), # Current Model
+ gr.Textbox.update(value=random_number(download_finish)), # Download Finish Trigger
+ gr.Textbox.update(value=queue_trigger), # Queue Trigger
+ gr.Button.update(interactive=False) # Cancel All Button
+ )
+ item = gl.download_queue[0]
+ gl.cancel_status = False
+ use_aria2 = getattr(opts, "use_aria2", True)
+ unpack_zip = getattr(opts, "unpack_zip", False)
+ save_all_images = getattr(opts, "auto_save_all_img", False)
+ gl.recent_model = item['model_name']
+ gl.last_version = item['version_name']
+
+ # Batch items are written to their existing location
+ if item['from_batch']:
+ item['install_path'] = item['existing_path']
+
+ gl.isDownloading = True
+ _file.make_dir(item['install_path'])
+
+ path_to_new_file = os.path.join(item['install_path'], item['model_filename'])
+
+ # The aria2 downloader is used unless it is disabled or os_type is 'Darwin'
+ if use_aria2 and os_type != 'Darwin':
+ thread = threading.Thread(target=download_file, args=(item['dl_url'], path_to_new_file, item['install_path'], progress))
+ else:
+ thread = threading.Thread(target=download_file_old, args=(item['dl_url'], path_to_new_file, progress))
+ # The thread is joined right away, so this function blocks until the download ends
+ thread.start()
+ thread.join()
+
+ # NOTE(review): this parses as `(not gl.cancel_status) or gl.download_fail`, so the
+ # branch is also entered after a failed download - confirm that this is intended.
+ if not gl.cancel_status or gl.download_fail:
+ if os.path.exists(path_to_new_file):
+ unpackList = []
+ if unpack_zip:
+ try:
+ if path_to_new_file.endswith('.zip'):
+ directory = Path(os.path.dirname(path_to_new_file))
+ zip_handler = ZipHandler(path_to_new_file)
+
+ # Remember the extracted names so they can be stored in the sidecar JSON
+ for _, decoded_name in zip_handler.name_map.items():
+ unpackList.append(decoded_name)
+
+ zip_handler.extract_all(directory)
+ zip_handler.zip_ref.close()
+
+ print(f"Successfully extracted {item['model_filename']} to {directory}")
+ os.remove(path_to_new_file)
+ except ImportError:
+ print("Python module 'ZipUnicode' has not been imported correctly, cannot extract zip file. Please try to restart or install it manually.")
+ except Exception as e:
+ print(f"Failed to extract {item['model_filename']} with error: {e}")
+ if not gl.cancel_status:
+ if item['create_json']:
+ _file.save_model_info(item['install_path'], item['model_filename'], item['sub_folder'], item['model_sha256'], item['preview_html'], api_response=item['model_json'])
+ info_to_json(path_to_new_file, item['model_id'], item['model_sha256'], unpackList)
+ _file.save_preview(path_to_new_file, item['model_json'], True, item['model_sha256'])
+ if save_all_images:
+ _file.save_images(item['preview_html'], item['model_filename'], item['install_path'], item['sub_folder'], api_response=item['model_json'])
+
+ base_name = os.path.splitext(item['model_filename'])[0]
+ base_name_preview = base_name + '.preview'
+
+ # After a failure, delete every file under the install path whose name without
+ # extension equals the model's base name or "<base name>.preview"
+ if gl.download_fail:
+ for root, dirs, files in os.walk(item['install_path'], followlinks=True):
+ for file in files:
+ file_base_name = os.path.splitext(file)[0]
+ if file_base_name == base_name or file_base_name == base_name_preview:
+ path_file = os.path.join(root, file)
+ os.remove(path_file)
+
+ if gl.cancel_status:
+ print(f'Cancelled download of "{item["model_filename"]}"')
+ else:
+ # The "NO_API" case has already been reported by the downloader
+ if not gl.download_fail == "NO_API":
+ print(f'Error occured during download of "{item["model_filename"]}"')
+
+ if gl.cancel_status:
+ card_name = None
+ else:
+ model_string = f"{item['model_name']} ({item['model_id']})"
+ (card_name, _, _) = _file.card_update(item['model_versions'], model_string, item['version_name'], True)
+
+ if len(gl.download_queue) != 0:
+ gl.download_queue.pop(0)
+ gl.isDownloading = False
+ time.sleep(2)
+
+ # Empty queue: change the finish trigger; otherwise change the queue trigger
+ if len(gl.download_queue) == 0:
+ finish_nr = random_number(download_finish)
+ queue_nr = queue_trigger
+ else:
+ finish_nr = download_finish
+ queue_nr = random_number(queue_trigger)
+
+ return (
+ gr.HTML.update(), # Download Progress HTML
+ gr.Textbox.update(value=card_name), # Current Model
+ gr.Textbox.update(value=finish_nr), # Download Finish Trigger
+ gr.Textbox.update(value=queue_nr), # Queue Trigger
+ gr.Button.update(interactive=True if len(gl.download_queue) > 1 else False) # Cancel All Button
+ )
+
+def remove_from_queue(dl_id):
+ global total_count
+ for item in gl.download_queue:
+ if int(dl_id) == int(item['dl_id']):
+ gl.download_queue.remove(item)
+ total_count -= 1
+ return
+
+def arrange_queue(input):
+ id_and_index = input.split('.')
+ dl_id = int(id_and_index[0])
+ index = int(id_and_index[1]) + 1
+ for item in gl.download_queue:
+ if int(item['dl_id']) == dl_id:
+ current_item = gl.download_queue.pop(gl.download_queue.index(item))
+ gl.download_queue.insert(index, current_item)
+ break
+
+def get_style(size, left_border):
+ return f"flex-grow: {size};" + ("border-left: 1px solid var(--border-color-primary);" if left_border else "") + "padding: 5px 10px 5px 10px;width: 0;align-self: center;"
+
+ # Append one HTML fragment per queued download that is not yet present in
+ # `current_html` and return the resulting markup. Entries already present are
+ # recognised by their dl_id="<number>" attribute.
+ # NOTE(review): the markup inside the string literals of this function appears to
+ # have been stripped in this copy. As written, rsplit("", 1) raises ValueError
+ # (empty separator) - restore the original literals before relying on this code.
+ def download_manager_html(current_html):
+ html = current_html.rsplit("", 1)[0]
+ pattern = r'dl_id="(\d+)"'
+ matches = re.findall(pattern, html)
+ existing_item_ids = [int(match) for match in matches]
+
+ for item in gl.download_queue:
+ # Only add queue entries that are not in the current markup yet
+ if not item['dl_id'] in existing_item_ids:
+ download_item = f'''
+
+
{item['model_name']}
+
{item['version_name']}
+
{item['install_path']}
+
In queue...
+
Remove
+
+ '''
+ html = html + download_item
+
+ html = html + ""
+
+ return html
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/scripts/civitai_file_manage.py b/sd-civitai-browser-plus/scripts/civitai_file_manage.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e3e51667d54e49ee804b72fb1da1b9c517a4821
--- /dev/null
+++ b/sd-civitai-browser-plus/scripts/civitai_file_manage.py
@@ -0,0 +1,1158 @@
+import json
+import gradio as gr
+import urllib.request
+import urllib.parse
+import urllib.error
+import os
+import io
+import re
+import time
+import errno
+import requests
+import hashlib
+from pathlib import Path
+from urllib.parse import urlparse
+
+from sympy import preview
+from modules.shared import cmd_opts, opts
+from scripts.civitai_global import print
+import scripts.civitai_global as gl
+import scripts.civitai_api as _api
+import scripts.civitai_file_manage as _file
+import scripts.civitai_download as _download
+
+ # Optional dependencies: a failed import only prints a message, so the imported
+ # name stays undefined and later uses of it must cope with that.
+ try:
+ from send2trash import send2trash
+ except ImportError:
+ print("Python module 'send2trash' has not been imported correctly, please try to restart or install it manually.")
+ try:
+ from bs4 import BeautifulSoup
+ except ImportError:
+ print("Python module 'BeautifulSoup' has not been imported correctly, please try to restart or install it manually.")
+
+ # Calls init() of scripts.civitai_global (imported above as `gl`)
+ gl.init()
+
+ # NOTE(review): the markup of this string literal appears to have been stripped in
+ # this copy (the literal now spans several lines) - restore it from the original.
+ offlineHTML = '
The Civit-API has timed out, please try again. The servers might be too busy or the selected model could not be found.
'
+ # style_html.css is looked up one directory above the directory of this file
+ css_path = Path(__file__).resolve().parents[1] / "style_html.css"
+ # Module-level flags, all initialised to False
+ no_update = False
+ from_ver = False
+ from_tag = False
+ from_installed = False
+try:
+ queue = not cmd_opts.no_gradio_queue
+except AttributeError:
+ queue = not cmd_opts.disable_queue
+except:
+ queue = True
+
+ # Delete the files of an installed model and return a 6-tuple of gradio updates
+ # (download button, cancel button, delete button, delete-finish trigger, current
+ # model, version list). Files are first matched through the 'sha256' value stored
+ # in *.json files below the model folder; if nothing was deleted that way, files
+ # are matched by name. Deletion tries send2trash first and falls back to
+ # os.remove on any exception.
+ def delete_model(delete_finish=None, model_filename=None, model_string=None, list_versions=None, sha256=None, selected_list=None, model_ver=None, model_json=None):
+ deleted = False
+ model_id = None
+
+ if model_string:
+ _, model_id = _api.extract_model_info(model_string)
+
+ # Use the supplied version data when given, otherwise fetch it through the API module
+ if not model_ver:
+ model_versions = _api.update_model_versions(model_id)
+ else: model_versions = model_ver
+
+ (model_name, ver_value, ver_choices) = _file.card_update(model_versions, model_string, list_versions, False)
+ # Content type and description are taken from gl.json_data (looked up by model id)
+ # unless `model_json` is supplied.
+ # NOTE(review): `selected_content_type` and `desc` are only assigned inside the
+ # branches below; confirm they cannot be unset when model_folder is computed.
+ if not model_json:
+ if model_id != None:
+ selected_content_type = None
+ for item in gl.json_data['items']:
+ if int(item['id']) == int(model_id):
+ selected_content_type = item['type']
+ desc = item['description']
+ break
+
+ if selected_content_type is None:
+ print("Model ID not found in json_data. (delete_model)")
+ return
+ else:
+ for item in model_json["items"]:
+ selected_content_type = item['type']
+ desc = item['description']
+
+ model_folder = os.path.join(_api.contenttype_folder(selected_content_type, desc))
+
+ # Delete based on provided SHA-256 hash
+ if sha256:
+ sha256_upper = sha256.upper()
+ for root, _, files in os.walk(model_folder, followlinks=True):
+ for file in files:
+ if file.endswith('.json'):
+ file_path = os.path.join(root, file)
+ try:
+ with open(file_path, 'r', encoding="utf-8") as json_file:
+ data = json.load(json_file)
+ file_sha256 = data.get('sha256', '')
+ if file_sha256:
+ file_sha256 = file_sha256.upper()
+ except Exception as e:
+ print(f"Failed to open: {file_path}: {e}")
+ # "0" is a placeholder for the comparison below when the file could not be read
+ file_sha256 = "0"
+
+ if file_sha256 == sha256_upper:
+ # Files recorded under 'unpackList' in the sidecar are removed as well
+ unpack_list = data.get('unpackList', [])
+ for unpacked_file in unpack_list:
+ unpacked_file_path = os.path.join(root, unpacked_file)
+ if os.path.isfile(unpacked_file_path):
+ try:
+ send2trash(unpacked_file_path)
+ print(f"File moved to trash based on unpackList: {unpacked_file_path}")
+ except:
+ os.remove(unpacked_file_path)
+ print(f"File deleted based on unpackList: {unpacked_file_path}")
+
+ base_name, _ = os.path.splitext(file)
+ if os.path.isfile(file_path):
+ try:
+ send2trash(file_path)
+ print(f"Model moved to trash based on SHA-256: {file_path}")
+ except:
+ os.remove(file_path)
+ print(f"Model deleted based on SHA-256: {file_path}")
+ delete_associated_files(root, base_name)
+ deleted = True
+
+ # Fallback to delete based on filename if not deleted based on SHA-256
+ # NOTE(review): `model_filename` defaults to None, which os.path.splitext does not
+ # accept - confirm that callers always pass a filename.
+ filename_to_delete = os.path.splitext(model_filename)[0]
+ aria2_file = model_filename + ".aria2"
+ if not deleted:
+ for root, dirs, files in os.walk(model_folder, followlinks=True):
+ for file in files:
+ current_file_name = os.path.splitext(file)[0]
+ if filename_to_delete == current_file_name or aria2_file == file:
+ path_file = os.path.join(root, file)
+ if os.path.isfile(path_file):
+ try:
+ send2trash(path_file)
+ print(f"Model moved to trash based on filename: {path_file}")
+ except:
+ os.remove(path_file)
+ print(f"Model deleted based on filename: {path_file}")
+ delete_associated_files(root, current_file_name)
+
+ number = _download.random_number(delete_finish)
+
+
+ # The download button is shown only when `selected_list` is empty or the string "[]"
+ btnDwn = not selected_list or selected_list == "[]"
+
+ return (
+ gr.Button.update(interactive=btnDwn, visible=btnDwn), # Download Button
+ gr.Button.update(interactive=False, visible=False), # Cancel Button
+ gr.Button.update(interactive=False, visible=False), # Delete Button
+ gr.Textbox.update(value=number), # Delete Finish Trigger
+ gr.Textbox.update(value=model_name), # Current Model
+ gr.Dropdown.update(value=ver_value, choices=ver_choices) # Version List
+ )
+
+def delete_associated_files(directory, base_name):
+ for file in os.listdir(directory):
+ current_base_name, ext = os.path.splitext(file)
+ if current_base_name == base_name or current_base_name == f"{base_name}.preview" or current_base_name == f"{base_name}.api_info":
+ path_to_delete = os.path.join(directory, file)
+ try:
+ send2trash(path_to_delete)
+ print(f"Associated file moved to trash: {path_to_delete}")
+ except:
+ os.remove(path_to_delete)
+ print(f"Associated file deleted: {path_to_delete}")
+
+ # Save a preview image for the model file at `file_path` as "<name>.preview.png"
+ # in the same directory. The model version is located in `api_response` by
+ # comparing the SHA256 hash of its files with `sha256`; when `sha256` is not given
+ # it is read from the "<name>.json" sidecar if that exists. An existing preview is
+ # kept unless `overwrite_toggle` is truthy.
+ def save_preview(file_path, api_response, overwrite_toggle=False, sha256=None):
+ json_file = os.path.splitext(file_path)[0] + ".json"
+ install_path, file_name = os.path.split(file_path)
+ name = os.path.splitext(file_name)[0]
+ filename = f'{name}.preview.png'
+ image_path = os.path.join(install_path, filename)
+
+ if not overwrite_toggle:
+ if os.path.exists(image_path):
+ return
+
+ # The hash is upper-cased before it is compared with the API data
+ if not sha256:
+ if os.path.exists(json_file):
+ try:
+ with open(json_file, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+ if 'sha256' in data and data['sha256']:
+ sha256 = data['sha256'].upper()
+ except Exception as e:
+ print(f"Failed to open {json_file}: {e}")
+ else:
+ sha256 = sha256.upper()
+
+ for item in api_response["items"]:
+ for version in item["modelVersions"]:
+ for file_entry in version["files"]:
+ if file_entry["hashes"].get("SHA256") == sha256:
+ for image in version["images"]:
+ # Only entries of type "image" are used as a preview
+ if image["type"] == "image":
+ # Rewrite the /width=<n> part of the URL to the image's own width
+ url_with_width = re.sub(r'/width=\d+', f'/width={image["width"]}', image["url"])
+
+ response = requests.get(url_with_width)
+ if response.status_code == 200:
+ with open(image_path, 'wb') as img_file:
+ img_file.write(response.content)
+ print(f"Preview saved at \"{image_path}\"")
+ else:
+ print(f"Failed to save preview. Status code: {response.status_code}")
+
+ # NOTE(review): indentation is not preserved in this copy; this return presumably
+ # ends the search after the first image attempt - confirm against the original.
+ return
+ print(f"No preview images found for \"{name}\"")
+ return
+
+ # Determine the directory in which images for a model are stored, create it via
+ # make_dir and return it. The default is `install_path`; the "image_location"
+ # and "sub_image_location" options can redirect it. Model metadata comes from the
+ # first item of `api_response`, or from gl.json_info when no response is given.
+ # NOTE(review): indentation is not preserved in this copy, so the pairing of the
+ # if/else branches below should be confirmed against the original file.
+ def get_image_path(install_path, api_response, sub_folder):
+ image_location = getattr(opts, "image_location", r"")
+ sub_image_location = getattr(opts, "sub_image_location", True)
+ image_path = install_path
+ if api_response:
+ json_info = api_response['items'][0]
+ else:
+ json_info = gl.json_info
+ if image_location:
+ if sub_image_location:
+ desc = json_info['description']
+ content_type = json_info['type']
+ image_path = os.path.join(_api.contenttype_folder(content_type, desc, custom_folder=image_location))
+
+ # "None" and the long placeholder text are compared as literal strings and
+ # are not treated as a sub folder
+ if sub_folder and sub_folder != "None" and sub_folder != "Only available if the selected files are of the same model type":
+ image_path = os.path.join(image_path, sub_folder.lstrip("/").lstrip("\\"))
+ else:
+ image_path = Path(image_location)
+ make_dir(image_path)
+ return image_path
+
+def save_images(preview_html, model_filename, install_path, sub_folder, api_response=None):
+ image_path = get_image_path(install_path, api_response, sub_folder)
+ img_urls = re.findall(r'data-sampleimg="true" src=[\'"]?([^\'" >]+)', preview_html)
+
+ name = os.path.splitext(model_filename)[0]
+
+ opener = urllib.request.build_opener()
+ opener.addheaders = [('User-agent', 'Mozilla/5.0')]
+ urllib.request.install_opener(opener)
+
+ for i, img_url in enumerate(img_urls):
+ filename = f'{name}_{i}.png'
+ img_url = urllib.parse.quote(img_url, safe=':/=')
+ try:
+ with urllib.request.urlopen(img_url) as url:
+ with open(os.path.join(image_path, filename), 'wb') as f:
+ f.write(url.read())
+ print(f"Downloaded {filename}")
+
+ except urllib.error.URLError as e:
+ print(f'Error: {e.reason}')
+
+def card_update(gr_components, model_name, list_versions, is_install):
+ if gr_components:
+ version_choices = gr_components['choices']
+ else:
+ print("Couldn't retrieve version, defaulting to installed")
+ model_name += ".New"
+ return model_name, None, None
+
+ if is_install and not gl.download_fail and not gl.cancel_status:
+ version_value_clean = list_versions + " [Installed]"
+ version_choices_clean = [version if version + " [Installed]" != version_value_clean else version_value_clean for version in version_choices]
+
+ else:
+ version_value_clean = list_versions.replace(" [Installed]", "")
+ version_choices_clean = [version if version.replace(" [Installed]", "") != version_value_clean else version_value_clean for version in version_choices]
+
+ first_version_installed = "[Installed]" in version_choices_clean[0]
+ any_later_version_installed = any("[Installed]" in version for version in version_choices_clean[1:])
+
+ if first_version_installed:
+ model_name += ".New"
+ elif any_later_version_installed:
+ model_name += ".Old"
+ else:
+ model_name += ".None"
+
+ return model_name, version_value_clean, version_choices_clean
+
+def list_files(folders):
+ model_files = []
+
+ extensions = ['.pt', '.ckpt', '.pth', '.safetensors', '.th', '.zip', '.vae']
+
+ for folder in folders:
+ if folder and os.path.exists(folder):
+ for root, _, files in os.walk(folder, followlinks=True):
+ for file in files:
+ _, file_extension = os.path.splitext(file)
+ if file_extension.lower() in extensions:
+ model_files.append(os.path.join(root, file))
+
+ model_files = sorted(list(set(model_files)))
+ return model_files
+
+def gen_sha256(file_path):
+ json_file = os.path.splitext(file_path)[0] + ".json"
+
+ if os.path.exists(json_file):
+ try:
+ with open(json_file, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+
+ if 'sha256' in data and data['sha256']:
+ hash_value = data['sha256']
+ return hash_value
+ except Exception as e:
+ print(f"Failed to open {json_file}: {e}")
+
+ def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE):
+ while True:
+ chunk = file.read(size)
+ if not chunk:
+ break
+ yield chunk
+
+ blocksize = 1 << 20
+ h = hashlib.sha256()
+ length = 0
+ with open(os.path.realpath(file_path), 'rb') as f:
+ for block in read_chunks(f, size=blocksize):
+ length += len(block)
+ h.update(block)
+
+ hash_value = h.hexdigest()
+
+ if os.path.exists(json_file):
+ try:
+ with open(json_file, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+
+ if 'sha256' in data and data['sha256']:
+ data['sha256'] = hash_value
+
+ with open(json_file, 'w', encoding="utf-8") as f:
+ json.dump(data, f, indent=4)
+ except Exception as e:
+ print(f"Failed to open {json_file}: {e}")
+ else:
+ data = {'sha256': hash_value}
+ with open(json_file, 'w', encoding="utf-8") as f:
+ json.dump(data, f, indent=4)
+
+ return hash_value
+
+ # Build the preview HTML for a model and return it as a one-element tuple holding
+ # a gradio Textbox update. The model file is either `path_input` or, when
+ # `path_input` is "Not Found", searched by name in the folders of the content
+ # type. A saved "<model>.html" file is used when the "use_local_html" option
+ # allows it; otherwise the model is looked up through the API helpers.
+ # NOTE(review): the markup inside several string literals of this function appears
+ # to have been stripped in this copy - restore the literals from the original.
+ def model_from_sent(model_name, content_type, tile_count, path_input):
+ modelID_failed = False
+ output_html = None
+ model_file = None
+ use_local_html = getattr(opts, "use_local_html", False)
+ local_path_in_html = getattr(opts, "local_path_in_html", False)
+
+ div = '
'
+ not_found = div + "Model ID not found. Maybe the model doesn\'t exist on CivitAI?
"
+ path_not_found = div + "Model ID not found. Could not locate the model path."
+ offline = div + "CivitAI failed to respond. The servers are likely offline."
+
+ # The saved HTML is not used when it was written with local image paths
+ if local_path_in_html:
+ use_local_html = False
+
+ if path_input == "Not Found":
+ # Remove a trailing ".<three digits>" suffix from both values
+ model_name = re.sub(r'\.\d{3}$', '', model_name)
+ content_type = re.sub(r'\.\d{3}$', '', content_type)
+ # Translate the card container id into a list of content types
+ content_mapping = {
+ "txt2img_textual_inversion_cards_html": ['TextualInversion'],
+ "txt2img_hypernetworks_cards_html": ['Hypernetwork'],
+ "txt2img_checkpoints_cards_html": ['Checkpoint'],
+ "txt2img_lora_cards_html": ['LORA', 'LoCon']
+ }
+ content_type = content_mapping.get(content_type, content_type)
+
+ extensions = ['.pt', '.ckpt', '.pth', '.safetensors', '.th', '.zip', '.vae']
+
+ # A later match overwrites an earlier one, so the last matching file is used
+ for content_type_item in content_type:
+ folder = _api.contenttype_folder(content_type_item)
+ for folder_path, _, files in os.walk(folder, followlinks=True):
+ for file in files:
+ if file.startswith(model_name) and file.endswith(tuple(extensions)):
+ model_file = os.path.join(folder_path, file)
+ if not model_file:
+ output_html = path_not_found
+ print(f'Error: Could not find model path for model: "{model_name}"')
+ print(f'Content type: "{content_type}"')
+ print(f'Main folder path: "{folder}"')
+ use_local_html = False
+ else:
+ model_file = path_input
+
+ if use_local_html:
+ html_file = os.path.splitext(model_file)[0] + ".html"
+ if os.path.exists(html_file):
+ with open(html_file, 'r', encoding='utf-8') as html:
+ output_html = html.read()
+ index = output_html.find("")
+ if index != -1:
+ output_html = output_html[index + len(""):]
+
+ # No HTML so far: resolve the model id and fetch the model data
+ if not output_html:
+ modelID = get_models(model_file, True)
+ if not modelID or modelID == "Model not found":
+ output_html = not_found
+ modelID_failed = True
+ if modelID == "offline":
+ output_html = offline
+ modelID_failed = True
+ if not modelID_failed:
+ json_data = _api.api_to_data(content_type, "Newest", "AllTime", "Model name", None, None, None, tile_count, f"civitai.com/models/{modelID}")
+ else:
+ json_data = None
+
+ if json_data == "timeout":
+ output_html = offline
+ if json_data != None and json_data != "timeout":
+ model_versions = _api.update_model_versions(modelID, json_data)
+ output_html = _api.update_model_info(None, model_versions.get('value'), True, modelID, json_data)
+
+ # Load the stylesheet and substitute fixed colours and a few rules
+ css_path = Path(__file__).resolve().parents[1] / "style_html.css"
+ with open(css_path, 'r', encoding='utf-8') as css_file:
+ css = css_file.read()
+ replacements = {
+ '#0b0f19': 'var(--body-background-fill)',
+ '#F3F4F6': 'var(--body-text-color)',
+ 'white': 'var(--body-text-color)',
+ '#80a6c8': 'var(--secondary-300)',
+ '#60A5FA': 'var(--link-text-color-hover)',
+ '#1F2937': 'var(--input-background-fill)',
+ '#374151': 'var(--input-border-color)',
+ '#111827': 'var(--error-background-fill)',
+ 'top: 50%;': '',
+ 'padding-top: 0px;': 'padding-top: 475px;',
+ '.civitai_txt2img': '.civitai_placeholder'
+ }
+
+ for old, new in replacements.items():
+ css = css.replace(old, new)
+
+ style_tag = f''
+ head_section = f'{style_tag}'
+
+ # Rename the zoom-related identifiers in the generated HTML
+ output_html = output_html.replace('display:flex;align-items:flex-start;', 'display:flex;align-items:flex-start;flex-wrap:wrap;justify-content:center;')
+ output_html = str(head_section + output_html)
+ output_html = output_html.replace('zoom-radio', 'zoom-preview-radio')
+ output_html = output_html.replace('zoomRadio', 'zoomPreviewRadio')
+ output_html = output_html.replace('zoom-overlay', 'zoom-preview-overlay')
+ output_html = output_html.replace('resetZoom', 'resetPreviewZoom')
+
+ number = _download.random_number()
+
+ return (
+ gr.Textbox.update(value=output_html, placeholder=number), # Preview HTML
+ )
+
+def is_image_url(url):
+ image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp']
+ parsed = urlparse(url)
+ path = parsed.path
+ return any(path.endswith(ext) for ext in image_extensions)
+
+def clean_description(desc):
+ try:
+ soup = BeautifulSoup(desc, 'html.parser')
+ for a in soup.find_all('a', href=True):
+ link_text = a.text + ' ' + a['href']
+ if not is_image_url(a['href']):
+ a.replace_with(link_text)
+
+ cleaned_text = soup.get_text()
+ except ImportError:
+ print("Python module 'BeautifulSoup' was not imported correctly, cannot clean description. Please try to restart or install it manually.")
+ cleaned_text = desc
+ return cleaned_text
+
+def make_dir(path):
+ try:
+ if not os.path.exists(path):
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno == errno.EACCES:
+ try:
+ os.makedirs(path, mode=0o777)
+ except OSError as e:
+ if e.errno == errno.EACCES:
+ print("Permission denied even with elevated permissions.")
+ else:
+ print(f"Error creating directory: {e}")
+ else:
+ print(f"Error creating directory: {e}")
+ except Exception as e:
+ print(f"Error creating directory: {e}")
+
+ # Save the metadata of a model next to its file: the "<name>.json" sidecar is
+ # written through find_and_save (matched by hash first, then by file name), the
+ # preview HTML is written as "<name>.html" when `preview_html` is given, and
+ # gl.json_info is dumped as "<name>.api_info.json" when the "save_api_info"
+ # option is set. `api_response` defaults to gl.json_data.
+ # NOTE(review): the markup inside some string literals and the regex below appears
+ # to have been stripped in this copy - restore them from the original file.
+ def save_model_info(install_path, file_name, sub_folder, sha256=None, preview_html=None, overwrite_toggle=False, api_response=None):
+ filename = os.path.splitext(file_name)[0]
+ json_file = os.path.join(install_path, f'{filename}.json')
+ make_dir(install_path)
+
+ save_api_info = getattr(opts, "save_api_info", False)
+ use_local = getattr(opts, "local_path_in_html", False)
+ save_to_custom = getattr(opts, "save_to_custom", False)
+
+ if not api_response:
+ api_response = gl.json_data
+
+ image_path = get_image_path(install_path, api_response, sub_folder)
+
+ # With "save_to_custom" the HTML and api_info files go to the image directory
+ if save_to_custom:
+ save_path = image_path
+ else:
+ save_path = install_path
+
+ # First try to match by hash (no_hash=False), then by file name (no_hash=True)
+ result = find_and_save(api_response, sha256, file_name, json_file, False, overwrite_toggle)
+ if result != "found":
+ result = find_and_save(api_response, sha256, file_name, json_file, True, overwrite_toggle)
+
+ if preview_html:
+ # Point the sample image URLs at the local "<name>_<n>.png" files
+ if use_local:
+ img_urls = re.findall(r'data-sampleimg="true" src=[\'"]?([^\'" >]+)', preview_html)
+ for i, img_url in enumerate(img_urls):
+ img_name = f'{filename}_{i}.png'
+ preview_html = preview_html.replace(img_url,f'{os.path.join(image_path, img_name)}')
+
+ match = re.search(r'(\s*)
', preview_html)
+ if match:
+ indentation = match.group(1)
+ else:
+ indentation = ''
+ css_link = f''
+ utf8_meta_tag = f'{indentation}'
+ head_section = f'{indentation}{indentation} {utf8_meta_tag}{indentation} {css_link}{indentation}'
+ HTML = head_section + preview_html
+ path_to_new_file = os.path.join(save_path, f'{filename}.html')
+ with open(path_to_new_file, 'wb') as f:
+ f.write(HTML.encode('utf8'))
+
+ if save_api_info:
+ path_to_new_file = os.path.join(save_path, f'{filename}.api_info.json')
+ with open(path_to_new_file, mode="w", encoding="utf-8") as f:
+ json.dump(gl.json_info, f, indent=4, ensure_ascii=False)
+
+
+def find_and_save(api_response, sha256=None, file_name=None, json_file=None, no_hash=None, overwrite_toggle=None):
+ for item in api_response.get('items', []):
+ for model_version in item.get('modelVersions', []):
+ for file in model_version.get('files', []):
+ file_name_api = file.get('name', '')
+ sha256_api = file.get('hashes', {}).get('SHA256', '')
+
+ if file_name == file_name_api if no_hash else sha256 == sha256_api:
+ gl.json_info = item
+ trained_words = model_version.get('trainedWords', [])
+ model_id = model_version.get('modelId', '')
+
+ if model_id:
+ model_url = f'Model URL: \"https://civitai.com/models/{model_id}\"\n'
+
+ description = item.get('description', '')
+ if description != None:
+ description = clean_description(description)
+ description = model_url + description
+
+ base_model = model_version.get('baseModel', '')
+
+ if base_model.startswith("SD 1"):
+ base_model = "SD1"
+ elif base_model.startswith("SD 2"):
+ base_model = "SD2"
+ elif base_model.startswith("SDXL"):
+ base_model = "SDXL"
+ else:
+ base_model = "Other"
+
+ if isinstance(trained_words, list):
+ trained_tags = ",".join(trained_words)
+ trained_tags = re.sub(r'<[^>]*:[^>]*>', '', trained_tags)
+ trained_tags = re.sub(r', ?', ', ', trained_tags)
+ trained_tags = trained_tags.strip(', ')
+ else:
+ trained_tags = trained_words
+
+ if os.path.exists(json_file):
+ with open(json_file, 'r', encoding="utf-8") as f:
+ try:
+ content = json.load(f)
+ except:
+ content = {}
+ else:
+ content = {}
+ changed = False
+ if not overwrite_toggle:
+ if "activation text" not in content or not content["activation text"]:
+ content["activation text"] = trained_tags
+ changed = True
+ if "description" not in content or not content["description"]:
+ content["description"] = description
+ changed = True
+ if "sd version" not in content or not content["sd version"]:
+ content["sd version"] = base_model
+ changed = True
+ else:
+ content["activation text"] = trained_tags
+ content["description"] = description
+ content["sd version"] = base_model
+ changed = True
+
+ with open(json_file, 'w', encoding="utf-8") as f:
+ json.dump(content, f, indent=4)
+
+ if changed: print(f"Model info saved to \"{json_file}\"")
+ return "found"
+
+ return "not found"
+
+def get_models(file_path, gen_hash=None):
+ # Resolve the CivitAI model ID for `file_path`, using the sidecar "<file_path without extension>.json"
+ # as a cache for the 'modelId' and 'sha256' values.
+ # Returns the model ID, the string "offline" (HTTP 503, timeout or connection error), or None
+ # (API payload containing 'error' on a 200 response, any other exception, or no cached ID while
+ # gen_hash is falsy).
+ # NOTE(review): indentation is lost in this copy, so the exact nesting of the sections below is inferred.
+ modelId = None
+ sha256 = None
+ json_file = os.path.splitext(file_path)[0] + ".json"
+ # Read previously cached values from the sidecar file, if it exists.
+ if os.path.exists(json_file):
+ try:
+ with open(json_file, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+
+ if 'modelId' in data:
+ modelId = data['modelId']
+ if 'sha256' in data and data['sha256']:
+ sha256 = data['sha256']
+ except Exception as e:
+ print(f"Failed to open {json_file}: {e}")
+
+ # Cache incomplete (or a previous lookup stored "Model not found"): hash the file when gen_hash
+ # is set, otherwise return whatever ID is cached without contacting the API.
+ if not modelId or not sha256 or modelId == "Model not found":
+ if gen_hash:
+ if not sha256:
+ sha256 = gen_sha256(file_path)
+ by_hash = f"https://civitai.com/api/v1/model-versions/by-hash/{sha256}"
+ else:
+ if modelId:
+ return modelId
+ else:
+ return None
+
+ try:
+ # Query the by-hash endpoint only when no usable ID is cached.
+ if not modelId or modelId == "Model not found":
+ response = requests.get(by_hash, timeout=(10,30))
+ if response.status_code == 200:
+ api_response = response.json()
+ if 'error' in api_response:
+ print(f"\"{file_path}\": {api_response['error']}")
+ return None
+ else:
+ modelId = api_response.get("modelId", "")
+ elif response.status_code == 503:
+ return "offline"
+ elif response.status_code == 404:
+ # The API's "error" text is stored as the model ID; presumably this is the
+ # "Model not found" value tested above — TODO confirm.
+ api_response = response.json()
+ modelId = api_response.get("error", "")
+
+ # Persist the resolved ID and the upper-cased hash back into the sidecar file.
+ # NOTE(review): status codes other than 200/503/404 leave modelId unchanged and still reach this write.
+ if os.path.exists(json_file):
+ try:
+ with open(json_file, 'r', encoding="utf-8") as f:
+ data = json.load(f)
+
+ data['modelId'] = modelId
+ data['sha256'] = sha256.upper()
+
+ with open(json_file, 'w', encoding="utf-8") as f:
+ json.dump(data, f, indent=4)
+ except Exception as e:
+ print(f"Failed to open {json_file}: {e}")
+ else:
+ data = {
+ 'modelId': modelId,
+ 'sha256': sha256.upper()
+ }
+ with open(json_file, 'w', encoding="utf-8") as f:
+ json.dump(data, f, indent=4)
+
+ return modelId
+ except requests.exceptions.Timeout:
+ print(f"Request timed out for {file_path}. Skipping...")
+ return "offline"
+ except requests.exceptions.ConnectionError:
+ print("Failed to connect to the API. The CivitAI servers might be offline.")
+ return "offline"
+ except Exception as e:
+ print(f"An error occurred for {file_path}: {str(e)}")
+ return None
+
+def version_match(file_paths, api_response):
+ # Classify the models in `api_response` by whether a local file matches their first listed version.
+ # Returns (updated_models, outdated_models); each entry is a tuple ("&ids=<item id>", <item name>).
+ # A match requires both the file name without extension and the upper-cased SHA256 to be equal.
+ updated_models = []
+ outdated_models = []
+ sha256_hashes = {}
+ # Collect the cached hash of every local file from its "<file>.json" sidecar.
+ for file_path in file_paths:
+ json_path = f"{os.path.splitext(file_path)[0]}.json"
+ if os.path.exists(json_path):
+ with open(json_path, 'r', encoding="utf-8") as f:
+ try:
+ json_data = json.load(f)
+ sha256 = json_data.get('sha256')
+ if sha256:
+ sha256_hashes[os.path.basename(file_path)] = sha256.upper()
+ except Exception as e:
+ print(f"Failed to open {json_path}: {e}")
+
+ # Build the lookup set of (name without extension, hash); files without a cached hash use "".
+ file_names_and_hashes = set()
+ for file_path in file_paths:
+ file_name = os.path.basename(file_path)
+ file_name_without_ext = os.path.splitext(file_name)[0]
+ file_sha256 = sha256_hashes.get(file_name, "")
+ if file_sha256: file_sha256 = file_sha256.upper()
+ file_names_and_hashes.add((file_name_without_ext, file_sha256))
+
+ for item in api_response.get('items', []):
+ model_versions = item.get('modelVersions', [])
+
+ if not model_versions:
+ continue
+
+ for idx, model_version in enumerate(model_versions):
+ files = model_version.get('files', [])
+ match_found = False
+ for file_entry in files:
+ entry_name = os.path.splitext(file_entry.get('name', ''))[0]
+ entry_sha256 = file_entry.get('hashes', {}).get('SHA256', "")
+ if entry_sha256: entry_sha256 = entry_sha256.upper()
+
+ if (entry_name, entry_sha256) in file_names_and_hashes:
+ match_found = True
+ break
+
+ # Index 0 is treated as the up-to-date version; a match at any later index counts as outdated.
+ # NOTE(review): with indentation lost it is unclear whether the trailing `break` applies to
+ # every match or only to the outdated branch — confirm against the upstream file.
+ if match_found:
+ if idx == 0:
+ updated_models.append((f"&ids={item['id']}", item["name"]))
+ else:
+ outdated_models.append((f"&ids={item['id']}", item["name"]))
+ break
+
+ return updated_models, outdated_models
+
+def get_content_choices(scan_choices=False):
+ use_LORA = getattr(opts, "use_LORA", False)
+ if use_LORA:
+ content_list = ["Checkpoint", "TextualInversion", "LORA & LoCon", "Poses", "Controlnet", "Hypernetwork", "AestheticGradient", "VAE", "Upscaler", "MotionModule", "Wildcards", "Workflows", "Other"]
+ else:
+ content_list = ["Checkpoint", "TextualInversion", "LORA", "LoCon", "Poses", "Controlnet", "Hypernetwork", "AestheticGradient", "VAE", "Upscaler", "MotionModule", "Wildcards", "Workflows", "Other"]
+ if scan_choices:
+ content_list.insert(0, 'All')
+ return content_list
+ return content_list
+
+def file_scan(folders, ver_finish, tag_finish, installed_finish, preview_finish, overwrite_toggle, tile_count, gen_hash, progress=gr.Progress() if queue else None):
+ # Scan the selected content folders, resolve a CivitAI model ID for every file found and then,
+ # depending on which module-level flag is set (from_ver / from_tag / from_installed / from_preview),
+ # look for outdated models, save model info and tags, prepare the installed-models listing,
+ # or save preview images.
+ # Every return visible here is a 2-tuple (gr.HTML update, gr.Textbox update carrying `number`).
+ # NOTE(review): indentation is lost and some HTML string literals are stripped in this copy of the
+ # source, so the nesting described in the comments below is inferred.
+ global from_ver, from_installed, no_update
+ update_log = getattr(opts, "update_log", True)
+ gl.scan_files = True
+ no_update = False
+ # NOTE(review): `number` is only assigned when one of the four flags is set; every return below reads it.
+ if from_ver:
+ number = _download.random_number(ver_finish)
+ elif from_tag:
+ number = _download.random_number(tag_finish)
+ elif from_installed:
+ number = _download.random_number(installed_finish)
+ elif from_preview:
+ number = _download.random_number(preview_finish)
+
+ # Nothing selected: report it, reset the scan state and stop.
+ if not folders:
+ if progress != None:
+ progress(0, desc=f"No folder selected.")
+ no_update = True
+ gl.scan_files = False
+ from_ver, from_installed = False, False
+ time.sleep(2)
+ return (gr.HTML.update(value=''),
+ gr.Textbox.update(value=number))
+
+ # Translate the selected content types into folders on disk.
+ folders_to_check = []
+ if 'All' in folders:
+ folders = _file.get_content_choices()
+
+ for item in folders:
+ if item == "LORA & LoCon":
+ folder = _api.contenttype_folder("LORA")
+ if folder:
+ folders_to_check.append(folder)
+ folder = _api.contenttype_folder("LoCon", fromCheck=True)
+ if folder:
+ folders_to_check.append(folder)
+ elif item == "Upscaler":
+ folder = _api.contenttype_folder(item, "SwinIR")
+ if folder:
+ folders_to_check.append(folder)
+ folder = _api.contenttype_folder(item, "RealESRGAN")
+ if folder:
+ folders_to_check.append(folder)
+ folder = _api.contenttype_folder(item, "GFPGAN")
+ if folder:
+ folders_to_check.append(folder)
+ folder = _api.contenttype_folder(item, "BSRGAN")
+ if folder:
+ folders_to_check.append(folder)
+ folder = _api.contenttype_folder(item, "ESRGAN")
+ if folder:
+ folders_to_check.append(folder)
+ else:
+ folder = _api.contenttype_folder(item)
+ if folder:
+ folders_to_check.append(folder)
+
+ total_files = 0
+ files_done = 0
+
+ files = list_files(folders_to_check)
+ total_files += len(files)
+
+ if total_files == 0:
+ if progress != None:
+ progress(1, desc=f"No files in selected folder.")
+ no_update = True
+ gl.scan_files = False
+ from_ver, from_installed = False, False
+ time.sleep(2)
+ return (gr.HTML.update(value=''),
+ gr.Textbox.update(value=number))
+
+ updated_models = []
+ outdated_models = []
+ all_model_ids = []
+ file_paths = []
+ all_ids = []
+
+ # Resolve a model ID per file; gl.cancel_status aborts the loop between files.
+ for file_path in files:
+ if gl.cancel_status:
+ if progress != None:
+ progress(files_done / total_files, desc=f"Processing files cancelled.")
+ no_update = True
+ gl.scan_files = False
+ from_ver, from_installed = False, False
+ time.sleep(2)
+ return (gr.HTML.update(value=''),
+ gr.Textbox.update(value=number))
+ file_name = os.path.basename(file_path)
+ if progress != None:
+ progress(files_done / total_files, desc=f"Processing file: {file_name}")
+
+ model_id = get_models(file_path, gen_hash)
+ if model_id == "offline":
+ print("The CivitAI servers did not respond, unable to retrieve Model ID")
+ elif model_id == "Model not found" and update_log:
+ print(f"model: \"{file_name}\" not found on CivitAI servers.")
+ elif model_id != None:
+ all_model_ids.append(f"&ids={model_id}")
+ all_ids.append(model_id)
+ file_paths.append(file_path)
+ elif not model_id and update_log:
+ print(f"model ID not found for: \"{file_name}\"")
+ files_done += 1
+
+ all_items = []
+
+ # De-duplicate the "&ids=<id>" query fragments.
+ all_model_ids = list(set(all_model_ids))
+
+ if not all_model_ids:
+ # NOTE(review): unlike the other call sites, this progress() call is not guarded against progress being None.
+ progress(1, desc=f"No model IDs could be retrieved.")
+ print("Could not retrieve any Model IDs, please make sure to turn on the \"One-Time Hash Generation for externally downloaded models.\" option if you haven't already.")
+ no_update = True
+ gl.scan_files = False
+ from_ver, from_installed = False, False
+ time.sleep(2)
+ return (gr.HTML.update(value=''),
+ gr.Textbox.update(value=number))
+
+ def chunks(lst, n):
+ for i in range(0, len(lst), n):
+ yield lst[i:i + n]
+
+ # Every mode except from_installed downloads the full model records, following 'nextPage' links.
+ if not from_installed:
+ model_chunks = list(chunks(all_model_ids, 500))
+
+ base_url = "https://civitai.com/api/v1/models?limit=100"
+ url_list = [f"{base_url}{''.join(chunk)}" for chunk in model_chunks]
+
+ url_count = len(all_model_ids) // 100
+ if len(all_model_ids) % 100 != 0:
+ url_count += 1
+ url_done = 0
+ api_response = {}
+ for url in url_list:
+ while url:
+ try:
+ if progress is not None:
+ progress(url_done / url_count, desc=f"Sending API request... {url_done}/{url_count}")
+ response = requests.get(url, timeout=(10, 30))
+ if response.status_code == 200:
+ api_response_json = response.json()
+
+ all_items.extend(api_response_json['items'])
+ metadata = api_response_json.get('metadata', {})
+ url = metadata.get('nextPage', None)
+ elif response.status_code == 503:
+ # NOTE(review): this return leaves gl.scan_files set to True; cancel_scan() polls that flag.
+ return (
+ gr.HTML.update(value=offlineHTML),
+ gr.Textbox.update(value=number)
+ )
+ else:
+ print(f"Error: Received status code {response.status_code} with URL: {url}")
+ url = None
+ url_done += 1
+ except requests.exceptions.Timeout:
+ print(f"Request timed out for {url}. Skipping...")
+ url = None
+ except requests.exceptions.ConnectionError:
+ print("Failed to connect to the API. The servers might be offline.")
+ url = None
+ except Exception as e:
+ print(f"An unexpected error occurred: {e}")
+ url = None
+
+ api_response['items'] = all_items
+ if api_response['items'] == []:
+ # NOTE(review): gl.scan_files is not reset on this return either.
+ return (
+ gr.HTML.update(value=offlineHTML),
+ gr.Textbox.update(value=number)
+ )
+
+ if progress != None:
+ progress(1, desc="Processing final results...")
+
+ # Update scan: keep only models that are outdated and not also matched as up to date.
+ if from_ver:
+ updated_models, outdated_models = version_match(file_paths, api_response)
+
+ updated_set = set(updated_models)
+ outdated_set = set(outdated_models)
+ outdated_set = {model for model in outdated_set if model[0] not in {updated_model[0] for updated_model in updated_set}}
+
+ all_model_ids = [model[0] for model in outdated_set]
+ all_model_names = [model[1] for model in outdated_set]
+
+ if update_log:
+ for model_name in all_model_names:
+ print(f'"{model_name}" is currently outdated.')
+
+ if len(all_model_ids) == 0:
+ no_update = True
+ gl.scan_files = False
+ from_ver = False
+ return (
+ gr.HTML.update(value='
No updates found for selected models.
'),
+ gr.Textbox.update(value=number)
+ )
+
+ # Split the IDs into pages of `tile_count` and store one request URL per page number, starting at 1.
+ model_chunks = list(chunks(all_model_ids, tile_count))
+
+ base_url = "https://civitai.com/api/v1/models?limit=100"
+ gl.url_list_with_numbers = {i+1: f"{base_url}{''.join(chunk)}" for i, chunk in enumerate(model_chunks)}
+
+ url_error = False
+ api_url = gl.url_list_with_numbers.get(1)
+
+ # NOTE(review): requests.get() runs before the `try`, so the Timeout/ConnectionError handlers below
+ # cannot catch its failures; the messages also format `url` rather than `api_url`.
+ if not url_error:
+ response = requests.get(api_url, timeout=(10,30))
+ try:
+ if response.status_code == 200:
+ response.encoding = "utf-8"
+ gl.ver_json = json.loads(response.text)
+
+ highest_number = max(gl.url_list_with_numbers.keys())
+ gl.ver_json["metadata"]["totalPages"] = highest_number
+
+ if highest_number > 1:
+ gl.ver_json["metadata"]["nextPage"] = gl.url_list_with_numbers.get(2)
+ else:
+ print(f"Error: Received status code {response.status_code} for URL: {url}")
+ url_error = True
+ except requests.exceptions.Timeout:
+ print(f"Request timed out for {url}. Skipping...")
+ url_error = True
+ except requests.exceptions.ConnectionError:
+ print("Failed to connect to the API. The servers might be offline.")
+ url_error = True
+ except Exception as e:
+ print(f"An unexpected error occurred: {e}")
+ url_error = True
+
+ if url_error:
+ gl.scan_files = False
+ return (
+ gr.HTML.update(value=offlineHTML),
+ gr.Textbox.update(value=number)
+ )
+ # NOTE(review): the literal below carries two different messages in this copy; code between them
+ # appears to have been lost together with the stripped HTML — restore from the upstream file.
+ elif from_ver:
+ gl.scan_files = False
+ return (
+ gr.HTML.update(value='
Outdated models have been found. Please press the button above to load the models into the browser tab
Installed models have been loaded. Please press the button above to load the models into the browser tab
'),
+ gr.Textbox.update(value=number)
+ )
+
+ # Tag/info mode: build the preview HTML per file and save the model info next to the file.
+ elif from_tag:
+ for file_path, id_value in zip(file_paths, all_ids):
+ install_path, file_name = os.path.split(file_path)
+ model_versions = _api.update_model_versions(id_value, api_response)
+ preview_html = _api.update_model_info(None, model_versions.get('value'), True, id_value, api_response)
+ sub_folder = os.path.normpath(os.path.relpath(install_path, gl.main_folder))
+ save_model_info(install_path, file_name, sub_folder, preview_html=preview_html, api_response=api_response, overwrite_toggle=overwrite_toggle)
+ if progress != None:
+ progress(1, desc=f"All tags succesfully saved!")
+ gl.scan_files = False
+ time.sleep(2)
+ return (
+ gr.HTML.update(value=''),
+ gr.Textbox.update(value=number)
+ )
+
+ # Preview mode: save a preview image for every file that resolved to a model ID.
+ elif from_preview:
+ completed_preview = 0
+ preview_count = len(file_paths)
+ for file in file_paths:
+ _, file_name = os.path.split(file)
+ name = os.path.splitext(file_name)[0]
+ if progress != None:
+ progress(completed_preview / preview_count, desc=f"Saving preview images... {completed_preview}/{preview_count} | {name}")
+ save_preview(file, api_response, overwrite_toggle)
+ completed_preview += 1
+ gl.scan_files = False
+ return (
+ gr.HTML.update(value=''),
+ gr.Textbox.update(value=number)
+ )
+
+def save_tag_start(tag_start):
+ global from_tag, from_ver, from_installed, from_preview
+ from_tag, from_ver, from_installed, from_preview = True, False, False, False
+ number = _download.random_number(tag_start)
+ return (
+ gr.Textbox.update(value=number),
+ gr.Button.update(interactive=False, visible=False),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.HTML.update(value='')
+ )
+
+def save_preview_start(preview_start):
+ global from_tag, from_ver, from_installed, from_preview
+ from_preview, from_tag, from_ver, from_installed = True, False, False, False
+ number = _download.random_number(preview_start)
+ return (
+ gr.Textbox.update(value=number),
+ gr.Button.update(interactive=False, visible=False),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.HTML.update(value='')
+ )
+
+def installed_models_start(installed_start):
+ global from_installed, from_ver, from_tag, from_preview
+ from_installed, from_ver, from_tag, from_preview = True, False, False, False
+ number = _download.random_number(installed_start)
+ return (
+ gr.Textbox.update(value=number),
+ gr.Button.update(interactive=False, visible=False),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.HTML.update(value='')
+ )
+
+def ver_search_start(ver_start):
+ global from_ver, from_tag, from_installed, from_preview
+ from_ver, from_tag, from_installed, from_preview = True, False, False, False
+ number = _download.random_number(ver_start)
+ return (
+ gr.Textbox.update(value=number),
+ gr.Button.update(interactive=False, visible=False),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.Button.update(interactive=False, visible=True),
+ gr.HTML.update(value='')
+ )
+
+def save_tag_finish():
+ global from_tag
+ from_tag = False
+ return (
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=False)
+ )
+
+def save_preview_finish():
+ global from_preview
+ from_preview = False
+ return (
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=False)
+ )
+
+def scan_finish():
+ return (
+ gr.Button.update(interactive=no_update, visible=no_update),
+ gr.Button.update(interactive=no_update, visible=no_update),
+ gr.Button.update(interactive=no_update, visible=no_update),
+ gr.Button.update(interactive=no_update, visible=no_update),
+ gr.Button.update(interactive=False, visible=False),
+ gr.Button.update(interactive=not no_update, visible=not no_update)
+ )
+
+def load_to_browser(content_type, sort_type, period_type, use_search_term, search_term, tile_count, base_filter, nsfw):
+ global from_ver, from_installed
+ if from_ver:
+ model_list_return = _api.update_model_list(from_ver=True, tile_count=tile_count)
+ if from_installed:
+ model_list_return = _api.update_model_list(from_installed=True, tile_count=tile_count)
+
+ use_LORA = getattr(opts, "use_LORA", False)
+ if content_type:
+ if use_LORA and 'LORA & LoCon' in content_type:
+ content_type.remove('LORA & LoCon')
+ if 'LORA' not in content_type:
+ content_type.append('LORA')
+ if 'LoCon' not in content_type:
+ content_type.append('LoCon')
+
+ current_inputs = (content_type, sort_type, period_type, use_search_term, search_term, tile_count, base_filter, nsfw)
+ gl.previous_inputs = current_inputs
+
+ gl.file_scan = True
+ from_ver, from_installed = False, False
+ return (
+ *model_list_return,
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=True, visible=True),
+ gr.Button.update(interactive=False, visible=False),
+ gr.Button.update(interactive=False, visible=False),
+ gr.HTML.update(value='')
+ )
+
+def cancel_scan():
+ gl.cancel_status = True
+
+ while True:
+ if not gl.scan_files:
+ gl.cancel_status = False
+ return
+ else:
+ time.sleep(0.5)
+ continue
diff --git a/sd-civitai-browser-plus/scripts/civitai_global.py b/sd-civitai-browser-plus/scripts/civitai_global.py
new file mode 100644
index 0000000000000000000000000000000000000000..d322db6549da0fbe55f194921757d44d78818d9c
--- /dev/null
+++ b/sd-civitai-browser-plus/scripts/civitai_global.py
@@ -0,0 +1,24 @@
+def init():
+ global download_queue, last_version, cancel_status, recent_model, json_data, json_info, main_folder, previous_inputs, download_fail, sortNewest, isDownloading, old_download, scan_files, ver_json, file_scan, url_list_with_numbers, print
+
+ cancel_status = None
+ recent_model = None
+ json_data = None
+ json_info = None
+ main_folder = None
+ previous_inputs = None
+ last_version = None
+ ver_json = None
+ url_list_with_numbers = None
+ download_queue = []
+
+ file_scan = False
+ scan_files = False
+ download_fail = False
+ sortNewest = False
+ isDownloading = False
+ old_download = False
+
+_print = print
+def print(print_message):
+ _print(f'\033[96mCivitAI Browser+\033[0m: {print_message}')
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/scripts/civitai_gui.py b/sd-civitai-browser-plus/scripts/civitai_gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..735c6fb8238eac0019384f58cd6eda2d2cbf9e50
--- /dev/null
+++ b/sd-civitai-browser-plus/scripts/civitai_gui.py
@@ -0,0 +1,1263 @@
+import gradio as gr
+from modules import script_callbacks, shared
+import os
+import json
+import fnmatch
+import re
+import subprocess
+from modules.shared import opts, cmd_opts
+from modules.paths import extensions_dir
+from scripts.civitai_global import print
+import scripts.civitai_global as gl
+import scripts.civitai_download as _download
+import scripts.civitai_file_manage as _file
+import scripts.civitai_api as _api
+
+def git_tag():
+ try:
+ return subprocess.check_output([os.environ.get('GIT', "git"), "describe", "--tags"], shell=False, encoding='utf8').strip()
+ except:
+ return None
+
+try:
+ import modules_forge
+ forge = True
+ ver_bool = True
+except ImportError:
+ forge = False
+
+if not forge:
+ try:
+ from packaging import version
+ ver = git_tag()
+
+ if not ver:
+ try:
+ from modules import launch_utils
+ ver = launch_utils.git_tag()
+ except:
+ print("Failed to fetch SD-WebUI version")
+ ver_bool = False
+ if ver:
+ ver = ver.split('-')[0].rsplit('-', 1)[0]
+ ver_bool = version.parse(ver[0:]) >= version.parse("1.7")
+ except ImportError:
+ print("Python module 'packaging' has not been imported correctly, please try to restart or install it manually.")
+ ver_bool = False
+
+gl.init()
+
+def saveSettings(ust, ct, pt, st, bf, cj, td, ol, hi, sn, ss, ts):
+ config = cmd_opts.ui_config_file
+
+ # Create a dictionary to map the settings to their respective variables
+ settings_map = {
+ "civitai_interface/Search type:/value": ust,
+ "civitai_interface/Content type:/value": ct,
+ "civitai_interface/Time period:/value": pt,
+ "civitai_interface/Sort by:/value": st,
+ "civitai_interface/Base model:/value": bf,
+ "civitai_interface/Save info after download/value": cj,
+ "civitai_interface/Divide cards by date/value": td,
+ "civitai_interface/Liked models only/value": ol,
+ "civitai_interface/Hide installed models/value": hi,
+ "civitai_interface/NSFW content/value": sn,
+ "civitai_interface/Tile size:/value": ss,
+ "civitai_interface/Tile count:/value": ts
+ }
+
+ # Load the current contents of the config file into a dictionary
+ try:
+ with open(config, "r", encoding="utf8") as file:
+ data = json.load(file)
+ except:
+ print(f"Cannot save settings, failed to open \"{file}\"")
+ print("Please try to manually repair the file or remove it to reset settings.")
+ return
+
+ # Remove any keys containing the text `civitai_interface`
+ keys_to_remove = [key for key in data if "civitai_interface" in key]
+ for key in keys_to_remove:
+ del data[key]
+
+ # Update the dictionary with the new settings
+ data.update(settings_map)
+
+ # Save the modified content back to the file
+ with open(config, 'w', encoding="utf-8") as file:
+ json.dump(data, file, indent=4)
+ print(f"Updated settings to: {config}")
+
+def all_visible(html_check):
+ return gr.Button.update(visible="model-checkbox" in html_check)
+
+def show_multi_buttons(model_list, type_list, version_value):
+ # Compute the visibility of the download/delete/save buttons and the choices of the
+ # "sub folder for selected files" dropdown.
+ # `model_list` and `type_list` arrive as JSON strings; `version_value` is the selected version label.
+ # Returns six updates: five Button updates and one Dropdown update (see the inline labels at the return).
+ # NOTE(review): indentation is lost in this copy, so the nesting of the folder walk below is inferred.
+ model_list = json.loads(model_list)
+ type_list = json.loads(type_list)
+ otherButtons = True
+ multi_file_subfolder = False
+ default_subfolder = "Only available if the selected files are of the same model type"
+ sub_folders = ["None"]
+ BtnDwn = version_value and not version_value.endswith('[Installed]') and not model_list
+ # NOTE(review): unlike BtnDwn above, this call (and the one in the return) is not guarded
+ # against an empty or None version_value.
+ BtnDel = version_value.endswith('[Installed]')
+
+ dot_subfolders = getattr(opts, "dot_subfolders", True)
+
+ # Multi-download is offered only with a selection and an empty download queue.
+ multi = bool(model_list) and not len(gl.download_queue) > 0
+ if model_list:
+ otherButtons = False
+ # Sub folders are listed only when every selected entry has the same content type.
+ if type_list and all(x == type_list[0] for x in type_list):
+ multi_file_subfolder = True
+ model_folder = os.path.join(_api.contenttype_folder(type_list[0]))
+ default_subfolder = "None"
+ try:
+ for root, dirs, _ in os.walk(model_folder, followlinks=True):
+ # Skip directories whose name, or any component of their path, starts with a dot.
+ if dot_subfolders:
+ dirs = [d for d in dirs if not d.startswith('.')]
+ dirs = [d for d in dirs if not any(part.startswith('.') for part in os.path.join(root, d).split(os.sep))]
+ for d in dirs:
+ sub_folder = os.path.relpath(os.path.join(root, d), model_folder)
+ if sub_folder:
+ sub_folders.append(f'{os.sep}{sub_folder}')
+ # Sort case-insensitively with "None" kept as the first entry.
+ sub_folders.remove("None")
+ sub_folders = sorted(sub_folders, key=lambda x: (x.lower(), x))
+ sub_folders.insert(0, "None")
+
+ # Remove duplicates while keeping order. Note that the local name `list` shadows the built-in.
+ list = set()
+ sub_folders = [x for x in sub_folders if not (x in list or list.add(x))]
+ except:
+ sub_folders = ["None"]
+
+ return (gr.Button.update(visible=multi, interactive=multi), # Download Multi Button
+ gr.Button.update(visible=BtnDwn if multi else True if not version_value.endswith('[Installed]') else False), # Download Button
+ gr.Button.update(visible=BtnDel if not model_list else False), # Delete Button
+ gr.Button.update(visible=otherButtons), # Save model info Button
+ gr.Button.update(visible=otherButtons), # Save images Button
+ gr.Dropdown.update(visible=multi, interactive=multi_file_subfolder, choices=sub_folders, value=default_subfolder) # Selected type sub folder
+ )
+
+def txt2img_output(image_url):
+ clean_url = image_url[4:]
+ geninfo = _api.fetch_and_process_image(clean_url)
+ if geninfo:
+ nr = _download.random_number()
+ geninfo = nr + geninfo
+ return gr.Textbox.update(value=geninfo)
+
+def on_ui_tabs():
+ page_header = getattr(opts, "page_header", False)
+ lobe_directory = None
+
+ for root, dirs, files in os.walk(extensions_dir, followlinks=True):
+ for dir_name in fnmatch.filter(dirs, '*lobe*'):
+ lobe_directory = os.path.join(root, dir_name)
+ break
+
+ # Different ID's for Lobe Theme
+ component_id = "togglesL" if lobe_directory else "toggles"
+ toggle1 = "toggle1L" if lobe_directory else "toggle1"
+ toggle2 = "toggle2L" if lobe_directory else "toggle2"
+ toggle3 = "toggle3L" if lobe_directory else "toggle3"
+ toggle5 = "toggle5L" if lobe_directory else "toggle5"
+ refreshbtn = "refreshBtnL" if lobe_directory else "refreshBtn"
+ filterBox = "filterBoxL" if lobe_directory else "filterBox"
+
+ if page_header:
+ header = "headerL" if lobe_directory else "header"
+ else:
+ header = "header_off"
+
+ api_key = getattr(opts, "custom_api_key", "")
+ if api_key:
+ toggle4 = "toggle4L_api" if lobe_directory else "toggle4_api"
+ show_only_liked = True
+ else:
+ toggle4 = "toggle4L" if lobe_directory else "toggle4"
+ show_only_liked = False
+
+ content_choices = _file.get_content_choices()
+ scan_choices = _file.get_content_choices(scan_choices=True)
+ with gr.Blocks() as civitai_interface:
+ with gr.Tab(label="Browser", elem_id="browserTab"):
+ with gr.Row(elem_id="searchRow"):
+ with gr.Accordion(label="", open=False, elem_id=filterBox):
+ with gr.Row():
+ use_search_term = gr.Radio(label="Search type:", choices=["Model name", "User name", "Tag"], value="Model name", elem_id="searchType")
+ with gr.Row():
+ content_type = gr.Dropdown(label='Content type:', choices=content_choices, value=None, type="value", multiselect=True, elem_id="centerText")
+ with gr.Row():
+ base_filter = gr.Dropdown(label='Base model:', multiselect=True, choices=["SD 1.4","SD 1.5","SD 1.5 LCM","SD 2.0","SD 2.0 768","SD 2.1","SD 2.1 768","SD 2.1 Unclip","SDXL 0.9","SDXL 1.0","SDXL 1.0 LCM","SDXL Distilled","SDXL Turbo","SVD","SVD XT","Playground v2","PixArt a","Other"], value=None, type="value", elem_id="centerText")
+ with gr.Row():
+ period_type = gr.Dropdown(label='Time period:', choices=["All Time", "Year", "Month", "Week", "Day"], value="All Time", type="value", elem_id="centerText")
+ sort_type = gr.Dropdown(label='Sort by:', choices=["Newest","Most Downloaded","Highest Rated","Most Liked", "Most Buzz","Most Discussed","Most Collected","Most Images"], value="Most Downloaded", type="value", elem_id="centerText")
+ with gr.Row(elem_id=component_id):
+ create_json = gr.Checkbox(label=f"Save info after download", value=True, elem_id=toggle1, min_width=171)
+ show_nsfw = gr.Checkbox(label="NSFW content", value=False, elem_id=toggle2, min_width=107)
+ toggle_date = gr.Checkbox(label="Divide cards by date", value=False, elem_id=toggle3, min_width=142)
+ only_liked = gr.Checkbox(label="Liked models only", value=False, interactive=show_only_liked, elem_id=toggle4, min_width=163)
+ hide_installed = gr.Checkbox(label="Hide installed models", value=False, elem_id=toggle5, min_width=170)
+ with gr.Row():
+ size_slider = gr.Slider(minimum=4, maximum=20, value=8, step=0.25, label='Tile size:')
+ tile_count_slider = gr.Slider(label="Tile count:", minimum=1, maximum=100, value=15, step=1, max_width=100)
+ with gr.Row(elem_id="save_set_box"):
+ save_settings = gr.Button(value="Save settings as default", elem_id="save_set_btn")
+ search_term = gr.Textbox(label="", placeholder="Search CivitAI", elem_id="searchBox")
+ refresh = gr.Button(label="", value="", elem_id=refreshbtn, icon="placeholder")
+ with gr.Row(elem_id=header):
+ with gr.Row(elem_id="pageBox"):
+ get_prev_page = gr.Button(value="Prev page", interactive=False, elem_id="pageBtn1")
+ page_slider = gr.Slider(label='Current page', step=1, minimum=1, maximum=1, value=1, min_width=80, elem_id="pageSlider")
+ get_next_page = gr.Button(value="Next page", interactive=False, elem_id="pageBtn2")
+ with gr.Row(elem_id="pageBoxMobile"):
+ pass # Row used for button placement on mobile
+ with gr.Row(elem_id="select_all_models_container"):
+ select_all = gr.Button(value="Select All", elem_id="select_all_models", visible=False)
+ with gr.Row():
+ list_html = gr.HTML(value='
Click the search icon to load models. Use the filter icon to filter results.
')
+ with gr.Row():
+ download_progress = gr.HTML(value='', elem_id="DownloadProgress")
+ with gr.Row():
+ list_models = gr.Dropdown(label="Model:", choices=[], interactive=False, elem_id="quicksettings1", value=None)
+ list_versions = gr.Dropdown(label="Version:", choices=[], interactive=False, elem_id="quicksettings", value=None)
+ file_list = gr.Dropdown(label="File:", choices=[], interactive=False, elem_id="file_list", value=None)
+ with gr.Row():
+ with gr.Column(scale=4):
+ install_path = gr.Textbox(label="Download folder:", interactive=False, max_lines=1)
+ with gr.Column(scale=2):
+ sub_folder = gr.Dropdown(label="Sub folder:", choices=[], interactive=False, value=None)
+ with gr.Row():
+ with gr.Column(scale=4):
+ trained_tags = gr.Textbox(label='Trained tags (if any):', value=None, interactive=False, lines=1)
+ with gr.Column(scale=2, elem_id="spanWidth"):
+ base_model = gr.Textbox(label='Base model: ', value='', interactive=False, lines=1, elem_id="baseMdl")
+ model_filename = gr.Textbox(label="Model filename:", interactive=False, value=None)
+ with gr.Row():
+ save_info = gr.Button(value="Save model info", interactive=False)
+ save_images = gr.Button(value="Save images", interactive=False)
+ delete_model = gr.Button(value="Delete model", interactive=False, visible=False)
+ download_model = gr.Button(value="Download model", interactive=False)
+ subfolder_selected = gr.Dropdown(label="Sub folder for selected files:", choices=[], interactive=False, visible=False, value=None, allow_custom_value=True)
+ download_selected = gr.Button(value="Download all selected", interactive=False, visible=False, elem_id="download_all_button")
+ with gr.Row():
+ cancel_all_model = gr.Button(value="Cancel all downloads", interactive=False, visible=False)
+ cancel_model = gr.Button(value="Cancel current download", interactive=False, visible=False)
+ with gr.Row():
+ preview_html = gr.HTML(elem_id="civitai_preview_html")
+ with gr.Row(elem_id="backToTopContainer"):
+ back_to_top = gr.Button(value="↑", elem_id="backToTop")
+ with gr.Tab("Update Models"):
+ with gr.Row():
+ selected_tags = gr.CheckboxGroup(elem_id="selected_tags", label="Scan for:", choices=scan_choices)
+ with gr.Row():
+ overwrite_toggle = gr.Checkbox(elem_id="overwrite_toggle", label="Overwrite any existing previews, tags or descriptions.", value=True)
+ with gr.Row():
+ skip_hash_toggle = gr.Checkbox(elem_id="skip_hash_toggle", label="One-Time Hash Generation for externally downloaded models.", value=True)
+ with gr.Row():
+ save_all_tags = gr.Button(value="Update model info & tags", interactive=True, visible=True)
+ cancel_all_tags = gr.Button(value="Cancel updating model info & tags", interactive=False, visible=False)
+ with gr.Row():
+ tag_progress = gr.HTML(value='')
+ with gr.Row():
+ update_preview = gr.Button(value="Update model preview", interactive=True, visible=True)
+ cancel_update_preview = gr.Button(value="Cancel updating model previews", interactive=False, visible=False)
+ with gr.Row():
+ preview_progress = gr.HTML(value='')
+ with gr.Row():
+ ver_search = gr.Button(value="Scan for available updates", interactive=True, visible=True)
+ cancel_ver_search = gr.Button(value="Cancel updates scan", interactive=False, visible=False)
+ load_to_browser = gr.Button(value="Load outdated models to browser", interactive=False, visible=False)
+ with gr.Row():
+ version_progress = gr.HTML(value='')
+ with gr.Row():
+ load_installed = gr.Button(value="Load all installed models", interactive=True, visible=True)
+ cancel_installed = gr.Button(value="Cancel loading models", interactive=False, visible=False)
+ load_to_browser_installed = gr.Button(value="Load installed models to browser", interactive=False, visible=False)
+ with gr.Row():
+ installed_progress = gr.HTML(value='')
+ with gr.Tab("Download Queue"):
+
+ def get_style(size, left_border):
+ return f"flex-grow: {size};" + ("border-left: 1px solid var(--border-color-primary);" if left_border else "") + "border-bottom: 1px solid var(--border-color-primary);padding: 5px 10px 5px 10px;width: 0;"
+
+ download_manager_html = gr.HTML(elem_id="civitai_dl_list", value=f'''
+
+
Model:
+
Version:
+
Path:
+
Status:
+
Action:
+
+
+
+ In queue: (drag items to rearrange queue order)
+
+
+ ''')
+
+ #Invisible triggers/variables
+
+ model_id = gr.Textbox(visible=False)
+ queue_trigger = gr.Textbox(visible=False)
+ dl_url = gr.Textbox(visible=False)
+ civitai_text2img_output = gr.Textbox(visible=False)
+ civitai_text2img_input = gr.Textbox(elem_id="civitai_text2img_input", visible=False)
+ selected_model_list = gr.Textbox(elem_id="selected_model_list", visible=False)
+ selected_type_list = gr.Textbox(elem_id="selected_type_list", visible=False)
+ html_cancel_input = gr.Textbox(elem_id="html_cancel_input", visible=False)
+ queue_html_input = gr.Textbox(elem_id="queue_html_input", visible=False)
+ model_path_input = gr.Textbox(elem_id="model_path_input", visible=False)
+ arrange_dl_id = gr.Textbox(elem_id="arrange_dl_id", visible=False)
+ remove_dl_id = gr.Textbox(elem_id="remove_dl_id", visible=False)
+ model_select = gr.Textbox(elem_id="model_select", visible=False)
+ model_sent = gr.Textbox(elem_id="model_sent", visible=False)
+ type_sent = gr.Textbox(elem_id="type_sent", visible=False)
+ download_start = gr.Textbox(visible=False)
+ download_finish = gr.Textbox(visible=False)
+ tag_start = gr.Textbox(visible=False)
+ tag_finish = gr.Textbox(visible=False)
+ preview_start = gr.Textbox(visible=False)
+ preview_finish = gr.Textbox(visible=False)
+ ver_start = gr.Textbox(visible=False)
+ ver_finish = gr.Textbox(visible=False)
+ installed_start = gr.Textbox(visible=None)
+ installed_finish = gr.Textbox(visible=None)
+ delete_finish = gr.Textbox(visible=False)
+ current_model = gr.Textbox(visible=False)
+ current_sha256 = gr.Textbox(visible=False)
+ model_preview_html = gr.Textbox(visible=False)
+
+ def ToggleDate(toggle_date):
+ gl.sortNewest = toggle_date
+
+ def select_subfolder(sub_folder):
+ if sub_folder == "None" or sub_folder == "Only available if the selected files are of the same model type":
+ newpath = gl.main_folder
+ else:
+ newpath = gl.main_folder + sub_folder
+ return gr.Textbox.update(value=newpath)
+
+ # Javascript Functions #
+
+ list_html.change(fn=None, inputs=hide_installed, _js="(toggleValue) => hideInstalled(toggleValue)")
+ hide_installed.input(fn=None, inputs=hide_installed, _js="(toggleValue) => hideInstalled(toggleValue)")
+
+ civitai_text2img_output.change(fn=None, inputs=civitai_text2img_output, _js="(genInfo) => genInfo_to_txt2img(genInfo)")
+
+ download_selected.click(fn=None, _js="() => deselectAllModels()")
+
+ select_all.click(fn=None, _js="() => selectAllModels()")
+
+ list_models.select(fn=None, inputs=list_models, _js="(list_models) => select_model(list_models)")
+
+ preview_html.change(fn=None, _js="() => adjustFilterBoxAndButtons()")
+
+ back_to_top.click(fn=None, _js="() => BackToTop()")
+
+ page_slider.release(fn=None, _js="() => pressRefresh()")
+
+ card_updates = [queue_trigger, download_finish, delete_finish]
+ for func in card_updates:
+ func.change(fn=None, inputs=current_model, _js="(modelName) => updateCard(modelName)")
+
+ list_html.change(fn=None, inputs=show_nsfw, _js="(hideAndBlur) => toggleNSFWContent(hideAndBlur)")
+ show_nsfw.change(fn=None, inputs=show_nsfw, _js="(hideAndBlur) => toggleNSFWContent(hideAndBlur)")
+
+ list_html.change(fn=None, inputs=size_slider, _js="(size) => updateCardSize(size, size * 1.5)")
+ size_slider.change(fn=None, inputs=size_slider, _js="(size) => updateCardSize(size, size * 1.5)")
+
+ model_preview_html.change(fn=None, inputs=model_preview_html, _js="(html_input) => inputHTMLPreviewContent(html_input)")
+
+ download_manager_html.change(fn=None, _js="() => setSortable()")
+
+ # Filter button Functions #
+
+ def HTMLChange(input):
+ return gr.HTML.update(value=input)
+
+ queue_html_input.change(fn=HTMLChange, inputs=[queue_html_input], outputs=download_manager_html)
+
+ remove_dl_id.change(
+ fn=_download.remove_from_queue,
+ inputs=[remove_dl_id]
+ )
+
+ arrange_dl_id.change(
+ fn=_download.arrange_queue,
+ inputs=[arrange_dl_id]
+ )
+
+ html_cancel_input.change(
+ fn=_download.download_cancel
+ )
+
+ html_cancel_input.change(fn=None, _js="() => cancelCurrentDl()")
+
+ save_settings.click(
+ fn=saveSettings,
+ inputs=[
+ use_search_term,
+ content_type,
+ period_type,
+ sort_type,
+ base_filter,
+ create_json,
+ toggle_date,
+ only_liked,
+ hide_installed,
+ show_nsfw,
+ size_slider,
+ tile_count_slider
+ ]
+ )
+
+ toggle_date.input(
+ fn=ToggleDate,
+ inputs=[toggle_date]
+ )
+
+ # Model Button Functions #
+
+ civitai_text2img_input.change(fn=txt2img_output,inputs=civitai_text2img_input,outputs=civitai_text2img_output)
+
+ list_html.change(fn=all_visible,inputs=list_html,outputs=select_all)
+
+ def update_models_dropdown(input):
+ model_string = re.sub(r'\.\d{3}$', '', input)
+ model_name, model_id = _api.extract_model_info(model_string)
+ model_versions = _api.update_model_versions(model_id)
+ (html, tags, base_mdl, DwnButton, SaveImages, DelButton, filelist, filename, dl_url, id, current_sha256, install_path, sub_folder) = _api.update_model_info(model_string, model_versions.get('value'))
+ return (gr.Dropdown.update(value=model_string, interactive=True),
+ model_versions,html,tags,base_mdl,filename,install_path,sub_folder,DwnButton,SaveImages,DelButton,filelist,dl_url,id,current_sha256,
+ gr.Button.update(interactive=True))
+
+ model_select.change(
+ fn=update_models_dropdown,
+ inputs=[model_select],
+ outputs=[
+ list_models,
+ list_versions,
+ preview_html,
+ trained_tags,
+ base_model,
+ model_filename,
+ install_path,
+ sub_folder,
+ download_model,
+ save_images,
+ delete_model,
+ file_list,
+ dl_url,
+ model_id,
+ current_sha256,
+ save_info
+ ]
+ )
+
+ model_sent.change(
+ fn=_file.model_from_sent,
+ inputs=[model_sent, type_sent, tile_count_slider, model_path_input],
+ outputs=[model_preview_html]
+ )
+
+ sub_folder.select(
+ fn=select_subfolder,
+ inputs=[sub_folder],
+ outputs=[install_path]
+ )
+
+ list_versions.select(
+ fn=_api.update_model_info,
+ inputs=[
+ list_models,
+ list_versions
+ ],
+ outputs=[
+ preview_html,
+ trained_tags,
+ base_model,
+ download_model,
+ save_images,
+ delete_model,
+ file_list,
+ model_filename,
+ dl_url,
+ model_id,
+ current_sha256,
+ install_path,
+ sub_folder
+ ]
+ )
+
+ file_list.input(
+ fn=_api.update_file_info,
+ inputs=[
+ list_models,
+ list_versions,
+ file_list
+ ],
+ outputs=[
+ model_filename,
+ dl_url,
+ model_id,
+ current_sha256,
+ download_model,
+ delete_model,
+ install_path,
+ sub_folder
+ ]
+ )
+
+ # Download/Save Model Button Functions #
+
+ selected_model_list.change(
+ fn=show_multi_buttons,
+ inputs=[selected_model_list, selected_type_list, list_versions],
+ outputs=[
+ download_selected,
+ download_model,
+ delete_model,
+ save_info,
+ save_images,
+ subfolder_selected
+ ]
+ )
+
+ download_model.click(
+ fn=_download.download_start,
+ inputs=[
+ download_start,
+ dl_url,
+ model_filename,
+ install_path,
+ list_models,
+ list_versions,
+ current_sha256,
+ model_id,
+ create_json,
+ download_manager_html
+ ],
+ outputs=[
+ download_model,
+ cancel_model,
+ cancel_all_model,
+ download_start,
+ download_progress,
+ download_manager_html
+ ]
+ )
+
+ download_selected.click(
+ fn=_download.selected_to_queue,
+ inputs=[
+ selected_model_list,
+ subfolder_selected,
+ download_start,
+ create_json,
+ download_manager_html
+ ],
+ outputs=[
+ download_model,
+ cancel_model,
+ cancel_all_model,
+ download_start,
+ download_progress,
+ download_manager_html
+ ]
+ )
+
+
+ for component in [download_start, queue_trigger]:
+ component.change(fn=None, _js="() => setDownloadProgressBar()")
+ component.change(
+ fn=_download.download_create_thread,
+ inputs=[download_finish, queue_trigger],
+ outputs=[
+ download_progress,
+ current_model,
+ download_finish,
+ queue_trigger
+ ]
+ )
+
+ download_finish.change(
+ fn=_download.download_finish,
+ inputs=[
+ model_filename,
+ list_versions,
+ model_id
+ ],
+ outputs=[
+ download_model,
+ cancel_model,
+ cancel_all_model,
+ delete_model,
+ download_progress,
+ list_versions
+ ]
+ )
+
+ cancel_model.click(_download.download_cancel)
+ cancel_all_model.click(_download.download_cancel_all)
+
+ cancel_model.click(fn=None, _js="() => cancelCurrentDl()")
+ cancel_all_model.click(fn=None, _js="() => cancelAllDl()")
+
+ delete_model.click(
+ fn=_file.delete_model,
+ inputs=[
+ delete_finish,
+ model_filename,
+ list_models,
+ list_versions,
+ current_sha256,
+ selected_model_list
+ ],
+ outputs=[
+ download_model,
+ cancel_model,
+ delete_model,
+ delete_finish,
+ current_model,
+ list_versions
+ ]
+ )
+
+ save_info.click(
+ fn=_file.save_model_info,
+ inputs=[
+ install_path,
+ model_filename,
+ sub_folder,
+ current_sha256,
+ preview_html
+ ],
+ outputs=[]
+ )
+
+ save_images.click(
+ fn=_file.save_images,
+ inputs=[
+ preview_html,
+ model_filename,
+ install_path,
+ sub_folder
+ ],
+ outputs=[]
+ )
+
+ # Common input&output lists #
+
+ page_inputs = [
+ content_type,
+ sort_type,
+ period_type,
+ use_search_term,
+ search_term,
+ page_slider,
+ base_filter,
+ only_liked,
+ show_nsfw,
+ tile_count_slider
+ ]
+
+ page_outputs = [
+ list_models,
+ list_versions,
+ list_html,
+ get_prev_page,
+ get_next_page,
+ page_slider,
+ save_info,
+ save_images,
+ download_model,
+ delete_model,
+ install_path,
+ sub_folder,
+ file_list,
+ preview_html,
+ trained_tags,
+ base_model,
+ model_filename
+ ]
+
+ file_scan_inputs = [
+ selected_tags,
+ ver_finish,
+ tag_finish,
+ installed_finish,
+ preview_finish,
+ overwrite_toggle,
+ tile_count_slider,
+ skip_hash_toggle
+ ]
+
+ load_to_browser_inputs = [
+ content_type,
+ sort_type,
+ period_type,
+ use_search_term,
+ search_term,
+ tile_count_slider,
+ base_filter,
+ show_nsfw
+ ]
+
+ cancel_btn_list = [cancel_all_tags,cancel_ver_search,cancel_installed,cancel_update_preview]
+
+ browser = [ver_search,save_all_tags,load_installed,update_preview]
+
+ browser_installed_load = [cancel_installed,load_to_browser_installed,installed_progress]
+ browser_load = [cancel_ver_search,load_to_browser,version_progress]
+
+ browser_installed_list = page_outputs + browser + browser_installed_load
+ browser_list = page_outputs + browser + browser_load
+
+ # Page Button Functions #
+
+ page_btn_list = {
+ refresh.click: _api.update_model_list,
+ search_term.submit: _api.update_model_list,
+ get_next_page.click: _api.update_next_page,
+ get_prev_page.click: _api.update_prev_page
+ }
+
+ for trigger, function in page_btn_list.items():
+ trigger(fn=function, inputs=page_inputs, outputs=page_outputs)
+ trigger(fn=None, _js="() => multi_model_select()")
+
+ for button in cancel_btn_list:
+ button.click(fn=_file.cancel_scan)
+
+ # Update model Functions #
+
+ ver_search.click(
+ fn=_file.ver_search_start,
+ inputs=[ver_start],
+ outputs=[
+ ver_start,
+ ver_search,
+ cancel_ver_search,
+ load_installed,
+ save_all_tags,
+ update_preview,
+ version_progress
+ ]
+ )
+
+ ver_start.change(
+ fn=_file.file_scan,
+ inputs=file_scan_inputs,
+ outputs=[
+ version_progress,
+ ver_finish
+ ]
+ )
+
+ ver_finish.change(
+ fn=_file.scan_finish,
+ outputs=[
+ ver_search,
+ save_all_tags,
+ load_installed,
+ update_preview,
+ cancel_ver_search,
+ load_to_browser
+ ]
+ )
+
+ load_installed.click(
+ fn=_file.installed_models_start,
+ inputs=[installed_start],
+ outputs=[
+ installed_start,
+ load_installed,
+ cancel_installed,
+ ver_search,
+ save_all_tags,
+ update_preview,
+ installed_progress
+ ]
+ )
+
+ installed_start.change(
+ fn=_file.file_scan,
+ inputs=file_scan_inputs,
+ outputs=[
+ installed_progress,
+ installed_finish
+ ]
+ )
+
+ installed_finish.change(
+ fn=_file.scan_finish,
+ outputs=[
+ ver_search,
+ save_all_tags,
+ load_installed,
+ update_preview,
+ cancel_installed,
+ load_to_browser_installed
+ ]
+ )
+
+ save_all_tags.click(
+ fn=_file.save_tag_start,
+ inputs=[tag_start],
+ outputs=[
+ tag_start,
+ save_all_tags,
+ cancel_all_tags,
+ load_installed,
+ ver_search,
+ update_preview,
+ tag_progress
+ ]
+ )
+
+ tag_start.change(
+ fn=_file.file_scan,
+ inputs=file_scan_inputs,
+ outputs=[
+ tag_progress,
+ tag_finish
+ ]
+ )
+
+ tag_finish.change(
+ fn=_file.save_tag_finish,
+ outputs=[
+ ver_search,
+ save_all_tags,
+ load_installed,
+ update_preview,
+ cancel_all_tags
+ ]
+ )
+
+ update_preview.click(
+ fn=_file.save_preview_start,
+ inputs=[preview_start],
+ outputs=[
+ preview_start,
+ update_preview,
+ cancel_update_preview,
+ load_installed,
+ ver_search,
+ save_all_tags,
+ preview_progress
+ ]
+ )
+
+ preview_start.change(
+ fn=_file.file_scan,
+ inputs=file_scan_inputs,
+ outputs=[
+ preview_progress,
+ preview_finish
+ ]
+ )
+
+ preview_finish.change(
+ fn=_file.save_preview_finish,
+ outputs=[
+ ver_search,
+ save_all_tags,
+ load_installed,
+ update_preview,
+ cancel_update_preview
+ ]
+ )
+
+ load_to_browser_installed.click(
+ fn=_file.load_to_browser,
+ inputs=load_to_browser_inputs,
+ outputs=browser_installed_list
+ )
+
+ load_to_browser.click(
+ fn=_file.load_to_browser,
+ inputs=load_to_browser_inputs,
+ outputs=browser_list
+ )
+
+ if ver_bool:
+ tab_name = "CivitAI Browser+"
+ else:
+ tab_name = "Civitai Browser+"
+
+ return (civitai_interface, tab_name, "civitai_interface"),
+
+def subfolder_list(folder, desc=None):
+    """Build the list of sub folder choices for one content type.
+
+    The model folder is resolved with ``_api.contenttype_folder(folder, desc)``
+    and walked recursively (following symlinks). Every directory found is
+    returned relative to that folder, prefixed with ``os.sep``.
+
+    Args:
+        folder: Content type name handed to ``_api.contenttype_folder``.
+        desc: Optional second argument handed to ``_api.contenttype_folder``.
+
+    Returns:
+        A list of strings whose first entry is ``"None"``, followed by the
+        case-insensitively sorted sub folders, with the placeholder templates
+        enabled through the ``insert_sub_N`` options inserted at index N.
+        Duplicates are removed, keeping the first occurrence. Returns ``None``
+        when ``folder`` is ``None`` or when resolving/walking the folder fails.
+    """
+    if folder is None:
+        return None
+
+    # Placeholder templates; the entry at position N-1 belongs to "insert_sub_N".
+    # NOTE(review): template 4 ends in "Version name" while the matching settings
+    # label says "Model version" - confirm which spelling the download code expects.
+    placeholders = (
+        f"{os.sep}Base model",
+        f"{os.sep}Base model{os.sep}Author name",
+        f"{os.sep}Base model{os.sep}Author name{os.sep}Model name",
+        f"{os.sep}Base model{os.sep}Author name{os.sep}Model name{os.sep}Version name",
+        f"{os.sep}Base model{os.sep}Model name",
+        f"{os.sep}Base model{os.sep}Model name{os.sep}Model version",
+        f"{os.sep}Author name",
+        f"{os.sep}Author name{os.sep}Base model",
+        f"{os.sep}Author name{os.sep}Base model{os.sep}Model name",
+        f"{os.sep}Author name{os.sep}Base model{os.sep}Model name{os.sep}Model version",
+        f"{os.sep}Author name{os.sep}Model name",
+        f"{os.sep}Author name{os.sep}Model name{os.sep}Model version",
+        f"{os.sep}Model name",
+        f"{os.sep}Model name{os.sep}Model version",
+    )
+    dot_subfolders = getattr(opts, "dot_subfolders", True)
+
+    try:
+        model_folder = _api.contenttype_folder(folder, desc)
+        found = []
+        for root, dirs, _ in os.walk(model_folder, followlinks=True):
+            if dot_subfolders:
+                # Prune in place so os.walk never descends into hidden folders.
+                # Only the directory's own name is tested: checking the full
+                # joined path would hide everything whenever the model folder
+                # itself lives below a dot-directory or is given as "./...".
+                dirs[:] = [d for d in dirs if not d.startswith('.')]
+            for d in dirs:
+                sub_folder = os.path.relpath(os.path.join(root, d), model_folder)
+                if sub_folder:
+                    found.append(f'{os.sep}{sub_folder}')
+
+        sub_folders = ["None"] + sorted(found, key=lambda x: (x.lower(), x))
+
+        # Same insert positions as before: template N goes to index N (or to
+        # the end when the list is shorter than that).
+        for number, placeholder in enumerate(placeholders, start=1):
+            if getattr(opts, f"insert_sub_{number}", False):
+                sub_folders.insert(number, placeholder)
+
+        # Order-preserving de-duplication (first occurrence wins).
+        sub_folders = list(dict.fromkeys(sub_folders))
+    except Exception:
+        # Best effort: an unreadable or unknown folder yields no choices.
+        return None
+    return sub_folders
+
+def make_lambda(folder, desc):
+    """Return a zero-argument callable producing the dropdown arguments.
+
+    The returned callable closes over ``folder`` and ``desc`` and, each time it
+    is called, returns ``{"choices": subfolder_list(folder, desc)}``.
+    """
+    def component_args():
+        return {"choices": subfolder_list(folder, desc)}
+    return component_args
+
+def on_ui_settings():
+    """Register every CivitAI Browser+ option with ``shared.opts``.
+
+    When ``ver_bool`` is true the options are split over a "Browser" and a
+    "Downloads" section and a "CivitAI Browser+" category is registered and
+    passed along as ``category_id``. Otherwise all options share one
+    "CivitAI Browser+" section and no ``category_id`` is passed.
+    """
+    if ver_bool:
+        browser = ("civitai_browser", "Browser")
+        download = ("civitai_browser_download", "Downloads")
+        from modules.options import categories
+        categories.register_category("civitai_browser_plus", "CivitAI Browser+")
+        category_kwargs = {'category_id': "civitai_browser_plus"}
+    else:
+        browser = download = ("civitai_browser_plus", "CivitAI Browser+")
+        category_kwargs = {}
+
+    # If OptionInfo has no callable .info(), install a fallback that appends
+    # the hint to the option label instead.
+    if not (hasattr(shared.OptionInfo, "info") and callable(getattr(shared.OptionInfo, "info"))):
+        def info(self, info):
+            self.label += f" ({info})"
+            return self
+        shared.OptionInfo.info = info
+
+    def register(name, default, label, section, hint=None, component=()):
+        """Add one option; ``component`` holds optional (component, component_args)."""
+        option = shared.OptionInfo(default, label, *component, section=section, **category_kwargs)
+        if hint is not None:
+            option = option.info(hint)
+        shared.opts.add_option(name, option)
+
+    # Download Options
+    register("use_aria2", True, "Download models using Aria2", download,
+             "Disable this option if you're experiencing any issues with downloads.")
+    register("disable_dns", False, "Disable Async DNS for Aria2", download,
+             "Useful for users who use PortMaster or other software that controls the DNS")
+    register("show_log", False, "Show Aria2 logs in console", download,
+             "Requires UI reload")
+    register("split_aria2", 64, "Number of connections to use for downloading a model", download,
+             "Only applies to Aria2",
+             component=(gr.Slider, lambda: {"maximum": "64", "minimum": "1", "step": "1"}))
+    register("aria2_flags", r"", "Custom Aria2 command line flags", download,
+             "Requires UI reload")
+    register("unpack_zip", False, "Automatically unpack .zip files after downloading", download)
+    register("save_api_info", False, "Save API info of model when saving model info", download,
+             "creates an api_info.json file when saving any model info with all the API data of the model")
+    register("auto_save_all_img", False, "Automatically save all images", download,
+             "Automatically saves all the images of a model after downloading")
+
+    # Browser Options
+    register("custom_api_key", r"", "Personal CivitAI API key", browser,
+             "You can create your own API key in your CivitAI account settings, this required for some downloads, Requires UI reload")
+    register("hide_early_access", True, "Hide early access models", browser,
+             "Early access models are only downloadable for supporter tier members")
+    register("use_LORA", ver_bool, "Treat LoCon's as LORA's", browser,
+             "SD-WebUI v1.5 and higher treats LoCON's the same as LORA's, Requires UI reload")
+    register("dot_subfolders", True, "Hide sub-folders that start with a '.'", browser)
+    register("use_local_html", False, "Use local HTML file for model info", browser,
+             "Uses the matching local HTML file when pressing CivitAI button on model cards in txt2img and img2img")
+    register("page_header", False, "Page navigation as header", browser,
+             "Keeps the page navigation always visible at the top, Requires UI reload")
+    register("video_playback", True, 'Gif/video playback in the browser', browser,
+             "Disable this option if you're experiencing high CPU usage during video/gif playback")
+    register("individual_meta_btn", True, 'Individual prompt buttons', browser,
+             "Turns individual prompts from an example image into a button to send it to txt2img")
+    register("update_log", True, 'Show console logs during update scanning', browser,
+             'Shows the "is currently outdated" messages in the console when scanning models for available updates')
+    register("image_location", r"", "Custom save images location", browser,
+             "Overrides the download folder location when saving images.")
+    register("sub_image_location", True, 'Use sub folders inside custom images location', browser,
+             "Will append any content type and sub folders to the custom path.")
+    register("local_path_in_html", False, "Use local images in the HTML", browser,
+             "Does not work in combination with the \"Use local HTML file for model info\" option!")
+    register("save_to_custom", False, "Store the HTML and api_info in the custom images location", browser)
+
+    # One toggle per placeholder template ("insert_sub_1" ... "insert_sub_14").
+    id_and_sub_options = {
+        "1" : f"{os.sep}Base model",
+        "2" : f"{os.sep}Base model{os.sep}Author name",
+        "3" : f"{os.sep}Base model{os.sep}Author name{os.sep}Model name",
+        "4" : f"{os.sep}Base model{os.sep}Author name{os.sep}Model name{os.sep}Model version",
+        "5" : f"{os.sep}Base model{os.sep}Model name",
+        "6" : f"{os.sep}Base model{os.sep}Model name{os.sep}Model version",
+        "7" : f"{os.sep}Author name",
+        "8" : f"{os.sep}Author name{os.sep}Base model",
+        "9" : f"{os.sep}Author name{os.sep}Base model{os.sep}Model name",
+        "10" : f"{os.sep}Author name{os.sep}Base model{os.sep}Model name{os.sep}Model version",
+        "11" : f"{os.sep}Author name{os.sep}Model name",
+        "12" : f"{os.sep}Author name{os.sep}Model name{os.sep}Model version",
+        "13" : f"{os.sep}Model name",
+        "14" : f"{os.sep}Model name{os.sep}Model version",
+    }
+
+    for number, string in id_and_sub_options.items():
+        register(f"insert_sub_{number}", False, f"Insert: [{string}]", browser)
+
+    use_LORA = getattr(opts, "use_LORA", False)
+
+    # Default sub folders
+    folders = [
+        "Checkpoint",
+        "LORA & LoCon" if use_LORA else "LORA",
+        "LoCon" if not use_LORA else None,
+        "TextualInversion",
+        "Poses",
+        "Controlnet",
+        "Hypernetwork",
+        "MotionModule",
+        ("Upscaler", "SWINIR"),
+        ("Upscaler", "REALESRGAN"),
+        ("Upscaler", "GFPGAN"),
+        ("Upscaler", "BSRGAN"),
+        ("Upscaler", "ESRGAN"),
+        "VAE",
+        "AestheticGradient",
+        "Wildcards",
+        "Workflows",
+        "Other"
+    ]
+
+    for folder in folders:
+        if folder is None:
+            continue
+        desc = None
+        if isinstance(folder, tuple):
+            folder_name = " - ".join(folder)
+            # Unpack in one step: assigning folder = folder[0] first and then
+            # reading folder[1] would yield the second character of the name.
+            folder, desc = folder
+            setting_name = f"{desc}_upscale"
+        else:
+            folder_name = folder
+            setting_name = folder
+        if folder == "LORA & LoCon":
+            folder = "LORA"
+            setting_name = "LORA_LoCon"
+
+        register(f"{setting_name}_subfolder", "None", folder_name, download,
+                 component=(gr.Dropdown, make_lambda(folder, desc)))
+
+# Hand the tab builder and the settings builder defined above to script_callbacks.
+script_callbacks.on_ui_tabs(on_ui_tabs)
+script_callbacks.on_ui_settings(on_ui_settings)
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/style.css b/sd-civitai-browser-plus/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..86b5fba4d7203ce469a23b04d51802e693a7086b
--- /dev/null
+++ b/sd-civitai-browser-plus/style.css
@@ -0,0 +1,695 @@
+/* Card list HTML */
+
+/* Wrapping, centered flex container for the cards */
+.civmodellist {
+    display: flex;
+    flex-wrap: wrap;
+    justify-content: center;
+}
+
+/* Each figure: rounded, clickable, animates transform and shadow */
+.civmodellist figure {
+    margin: 6px;
+    transition: transform .3s ease-out, box-shadow 0.3s ease;
+    cursor: pointer;
+    border-radius: 10px;
+}
+
+/* Positioning context for the absolutely placed caption */
+.civmodelcard {
+    position: relative;
+}
+
+/* Hover: enlarge, raise above siblings, white outline */
+.civmodelcard:hover {
+    transform: scale(1.1, 1.1);
+    position: relative;
+    z-index: var(--layer-5);
+    box-shadow: 0px 0px 1px 3px whitesmoke;
+}
+
+/* Aquamarine outline */
+.civmodelcardinstalled {
+    box-shadow: 0px 0px 1px 3px aquamarine;
+}
+
+/* Orange outline */
+.civmodelcardoutdated {
+    box-shadow: 0px 0px 1px 3px orange;
+}
+
+/* Hover: release the caption from the bottom edge and darken it */
+.civmodelcard:hover figcaption {
+    bottom: initial;
+    background-color: rgba(32, 32, 32, 0.9);
+}
+
+/* Fixed-size, cropped preview media */
+.civmodelcard img,
+.civmodelcard .video-bg {
+    width: 8em;
+    height: 12em;
+    object-fit: cover;
+    border-radius: 10px;
+}
+
+/* Caption overlaid near the bottom of the card */
+.civmodelcard figcaption {
+    position: absolute;
+    bottom: 5px;
+    text-align: center;
+    width: 8em;
+    word-break: break-word;
+    background-color: rgba(32, 32, 32, 0.5);
+    color: white !important;
+}
+
+/* End of Card list HTML */
+
+/* Remove the width cap on direct children of #quicksettings */
+#quicksettings > div {
+    max-width: none !important;
+    width: auto !important;
+}
+
+/* Vertical offsets for the two toggle containers */
+#togglesL {
+    margin-top: 3px;
+}
+
+#toggles {
+    margin-top: -10px;
+}
+
+/* Spacing between the children of #searchType */
+#searchType > div {
+    gap: 0.5em;
+}
+
+/* Fixed bottom-right container; ignores pointer events itself */
+#backToTopContainer {
+    position: fixed;
+    bottom: 0;
+    right: 0;
+    display: flex;
+    justify-content: flex-end;
+    z-index: 150;
+    pointer-events: none;
+    margin: 20px 51px 20px 20px;
+}
+
+/* The button inside re-enables pointer events so it stays clickable */
+#backToTop {
+    margin: 0;
+    max-width: 60px;
+    min-width: unset;
+    z-index: 200;
+    pointer-events: auto;
+}
+
+/* Positioning context for the browser tab */
+#browserTab {
+    position: relative;
+}
+
+#browserTab > div {
+    gap: var(--layout-gap) !important;
+}
+
+/* #header sticks to the top of the scroll area with an opaque background */
+#browserTab > div > #header {
+    position: -webkit-sticky;
+    position: sticky;
+    top: 0;
+    background-color: var(--neutral-950);
+    z-index: 60;
+}
+
+.acss-14flpmm .gap:has(#quicksettings):first-child {
+    gap: var(--layout-gap);
+}
+
+#txt2img_seed > label > input {
+    height: unset !important;
+}
+
+/* Column layout shared by #header and #header_off; the padding and the
+   negative margin cancel each other out */
+#browserTab > div > #header,
+#browserTab > div > #header_off {
+    display: flex;
+    flex-direction: column;
+    padding-top: 15px;
+    margin-top: -15px;
+}
+
+/* Centered toggles with a small top margin and no side margins */
+#toggle1,
+#toggle2,
+#toggle3,
+#toggle4,
+#toggle4_api,
+#toggle5 {
+    margin-top: 5px;
+    margin-right: 0px;
+    margin-left: 0px;
+    display: flex;
+    justify-content: center;
+}
+
+/* Centered toggles without margin changes */
+#toggle1L,
+#toggle2L,
+#toggle3L,
+#toggle4L,
+#toggle4L_api,
+#toggle5L,
+#overwrite_toggle,
+#skip_hash_toggle {
+    display: flex;
+    justify-content: center;
+}
+
+#centerText,
+#searchType {
+    text-align: center;
+}
+
+/* Minimum height of the browser tab */
+#browserTab {
+    min-height: 650px;
+}
+
+/* Fixed 40px height, aligned to the end of its row */
+#download_all_button {
+    max-height: 40px;
+    height: 40px;
+    align-self: end;
+    margin-bottom: 1px;
+}
+
+#searchBox > label > textarea {
+    padding-top: 11px !important;
+}
+
+/* Search box capped at 800px and centered */
+#searchBox {
+    max-width: 800px;
+    align-self: center;
+}
+
+/* Base model field pinned to 100px */
+#baseMdl {
+    min-width: 100px !important;
+    max-width: 100px !important;
+}
+
+/* Lay the children of #spanWidth out in one non-wrapping row */
+#spanWidth {
+    display: flex !important;
+    flex-direction: row;
+}
+
+#spanWidth > div {
+    flex-wrap: nowrap;
+}
+
+.gradio-container-3-32-0 .prose :last-child {
+    margin-bottom: auto !important;
+}
+
+/* Full-width, centered block */
+.date-section {
+    display: block;
+    width: 100%;
+    margin-bottom: 5px;
+    text-align: center;
+}
+
+/* Wrapping, centered flex row */
+.card-row {
+    display: flex;
+    flex-wrap: wrap;
+    justify-content: center;
+}
+
+#selected_tags {
+    text-align: center;
+}
+
+/* Page buttons: between 50px and 120px wide */
+#pageBtn1,
+#pageBtn2 {
+    max-width: 120px !important;
+    min-width: 50px !important;
+}
+
+#pageSlider {
+    max-height: 44px;
+}
+
+#pageSlider > div:nth-child(2) {
+    max-height: 25px;
+}
+
+#pageBoxMobile {
+    display: flex;
+    justify-content: space-between;
+}
+
+/* Centered box, capped at 950px */
+#pageBox {
+    display: flex;
+    justify-content: center;
+    align-self: center;
+    max-width: 950px !important;
+}
+
+#pageBox > div:first-child {
+    align-items: end;
+}
+
+/* Square 42px buttons without padding */
+#refreshBtn,
+#refreshBtnL {
+    align-self: end;
+    height: 42px !important;
+    min-height: 42px !important;
+    max-height: 42px !important;
+    max-width: 42px !important;
+    min-width: 42px !important;
+    width: 42px !important;
+    padding: 0px !important;
+}
+
+#refreshBtn > img,
+#refreshBtnL > img {
+    margin: unset;
+}
+
+/* Row capped at 800px and centered */
+#searchRow {
+    max-width: 800px;
+    align-self: center;
+}
+
+#save_set_box {
+    display: flex;
+    justify-content: center;
+}
+
+/* Button pinned to 220px wide, at least 35px tall */
+#save_set_btn {
+    max-width: 220px !important;
+    min-width: 220px !important;
+    margin-bottom: -6px;
+    padding: 5px;
+    height: unset !important;
+    min-height: 35px !important;
+}
+
+#searchType > div:nth-child(3) {
+    justify-content: center;
+}
+
+/* 20px box pinned to the top-right corner of its positioned ancestor */
+.custom-checkbox {
+    position: absolute;
+    top: 10px;
+    right: 10px;
+    width: 20px;
+    min-width: 20px;
+    height: 20px;
+    background: #111B;
+    border-radius: var(--checkbox-border-radius);
+    border: 1px solid #bbbbbb;
+    cursor: pointer;
+}
+
+.custom-checkbox:hover {
+    border-color: #ffffff;
+}
+
+/* Checked state: styles the box that directly follows a checked .model-checkbox */
+.model-checkbox:checked + .custom-checkbox {
+    background-color: var(--checkbox-background-color-selected);
+    border-color: var(--checkbox-border-color-selected);
+    background-image: var(--checkbox-check);
+    background-size: contain;
+    background-position: center;
+    background-repeat: no-repeat;
+}
+
+.open-in-civitai {
+ font-size: 18pt;
+ color: var(--body-text-color);
+ display: flex;
+ justify-content: center;
+ margin-top: -12px;
+}
+
+#model_header:hover{
+ color: var(--link-text-color-hover);
+}
+
+.civitai-txt2img-btn:hover {
+ border-color: var(--button-secondary-border-color-hover);
+ background: var(--button-secondary-background-fill-hover);
+ color: var(--button-secondary-text-color-hover);
+}
+
+.civitai-txt2img-btn {
+ border-radius: var(--button-large-radius);
+ border: var(--button-border-width) solid var(--button-secondary-border-color);
+ padding: var(--button-large-padding);
+ font-weight: var(--button-large-text-weight);
+ font-size: var(--button-large-text-size);
+ background: var(--button-secondary-background-fill);
+ color: var(--button-secondary-text-color);
+}
+
+.civitai-tags-container {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 5px;
+}
+
+.civitai-tag,
+.civitai-meta,
+.civitai-meta-btn {
+ background-color: var(--error-background-fill);
+ border-radius: 8px;
+ padding: 4px 6px;
+ border: 1px solid var(--input-border-color);
+}
+
+.civitai-meta-btn:hover {
+ cursor: pointer;
+ background-color: var(--input-background-fill);
+}
+
+#select_all_models_container {
+ display: flex;
+ justify-content: flex-end;
+}
+
+#select_all_models {
+ max-width: 100px;
+ min-width: 100px;
+ min-height: 30px;
+ padding: 0px;
+ margin-top: -25px;
+}
+
+.civitai_dl_item,
+.civitai_dl_item_completed,
+.civitai_dl_item_failed {
+ background-color: var(--error-background-fill);
+ border-radius: 8px;
+ padding: 5px 0px;
+ border: 1px solid var(--input-border-color);
+ margin: 10px 0px;
+}
+
+.civitai_dl_item_failed > .dl_stat > .dl_progress_bar {
+ background-color: transparent !important;
+ padding: 0px 0px 2px 0px !important;
+}
+
+.dl_progress_bar {
+ background-color: var(--button-primary-border-color);
+ color: var(--body-text-color);
+ padding: 0px 0px 2px 5px;
+ border-radius: 8px;
+ transition: width 0.5s ease-in-out;
+}
+
+.dl_progress_bar::before,
+.dl_progress_bar::after { /* empty generated boxes, displayed as tables with clear: both (clearfix-style) */
+ content: "";
+ display: table;
+ clear: both;
+}
+
+.civitai-btn-text:hover {
+ color: var(--link-text-color-hover);
+ cursor: pointer;
+}
+/* Customized Accordion Filter */
+
+#filterBox,
+#filterBoxL {
+ align-self: end;
+ height: 42px;
+ max-width: 42px;
+ padding: unset !important;
+ margin: 0px !important;
+ display: flex;
+ justify-content: center;
+}
+
+#filterBox {
+ background: var(--button-secondary-background-fill);
+}
+
+#filterBoxL {
+ background: var(--input-background-fill);
+}
+
+#filterBox:hover,
+#filterBoxL:hover {
+ background: var(--button-secondary-background-fill-hover);
+}
+
+#filterBox .label-wrap.open,
+#filterBoxL .label-wrap.open{
+ border-bottom: unset !important;
+ background: var(--button-secondary-background-fill-hover);
+ border-radius: 7px !important;
+ height: 40px;
+}
+
+#filterBox > div:nth-child(3),
+#filterBoxL > div:nth-child(3) {
+ padding: 20px;
+ position: absolute;
+ border-radius: 10px;
+ width: 300px;
+ z-index: 100 !important;
+ margin-top: 55px;
+}
+
+.browser_tooltip {
+ box-shadow: var(--body-text-color) 0px 0px 2px 0px;
+ background: var(--background-fill-primary);
+ color: var(--body-text-color);
+ border-radius: 3px;
+ padding: 10px;
+ position: absolute;
+ z-index: 50;
+ margin-top: 30px;
+}
+
+#toggle4 > label > span, #toggle4L > label > span {
+ color: var(--neutral-400);
+}
+
+#filterBox > div:nth-child(3), #toggle4 > div:nth-child(3) {
+ background: var(--background-fill-primary);
+}
+
+#filterBoxL > div:nth-child(3), #toggle4L > div:nth-child(3) {
+ background: var(--neutral-950);
+}
+
+#filterBox > div:nth-child(2),
+#filterBoxL > div:nth-child(2) {
+ padding: 10px !important;
+}
+
+#filterBox .gradio-slider input[type="number"],
+#filterBoxL .gradio-slider input[type="number"] {
+ width: 70px !important;
+}
+
+#pageBox .gradio-slider input[type="number"] {
+ width: 5em !important;
+}
+
+#filterBox > div:nth-child(2) > span:nth-child(1),
+#filterBoxL > div:nth-child(2) > span:nth-child(1) {
+ display: none;
+}
+
+#filterBox > div:nth-child(2) > span:nth-child(2),
+#filterBoxL > div:nth-child(2) > span:nth-child(2) {
+ transform: rotate(0deg) !important;
+ transition: 0s !important;
+ display: inline-block;
+ width: 24px;
+ height: 24px;
+ font-size: 0;
+ color: transparent;
+ overflow: hidden;
+}
+
+#filterBox > div:nth-child(2) > span:nth-child(2)::before,
+#filterBoxL > div:nth-child(2) > span:nth-child(2)::before {
+ content: "";
+ display: block;
+ width: 100%;
+ height: 100%;
+}
+
+/* End of Custom Accordion */
+
+.goto-civitbrowser.card-button {
+ filter: drop-shadow(2px 2px 3px black);
+}
+
+.goto-civitbrowser.card-button:hover svg {
+ fill: red !important;
+}
+
+/* Custom settings Accordion */
+#settings-accordion {
+ border: 1px solid var(--block-border-color);
+ border-radius: 8px;
+ margin: 15px 0px 2px 0px;
+ padding: 8px 8px;
+}
+
+#accordionToggle {
+ width: 100%;
+ display: flex;
+ font-size: 12pt;
+ justify-content: space-between;
+}
+
+#selected_tags > div {
+ justify-content: center;
+ padding-top: 10px;
+ padding-bottom: 20px;
+}
+
+#civitai_preview_html .model-block {
+ box-shadow: 0px 0px 1px 3px #3339ff30;
+ border-radius: 10px;
+ padding: 1px 20px 10px;
+ margin-bottom: 20px;
+}
+
+#civitai_preview_html .model-block code {
+ white-space: pre-wrap;
+}
+
+#civitai_preview_html .model-block dl {
+ overflow-wrap: anywhere;
+}
+
+#civitai_preview_html .sampleimgs .model-block img,
+#civitai_preview_html .sampleimgs .model-block video {
+ padding-top: 1em;
+ max-width: 20em;
+ cursor: zoom-in;
+ transition: max-width 0.1s;
+}
+
+/* Preview Image zoom */
+#civitai_preview_html .zoom-radio {
+ display: none!important;
+}
+
+/* Style for when the image is clicked (radio button checked) */
+#civitai_preview_html .zoom-radio:checked + label > img,
+#civitai_preview_html .zoom-radio:checked + label > video {
+ max-width: 95vw;
+ max-height: 95vh;
+ padding-top: 0px;
+ cursor: zoom-out;
+ position: fixed;
+ top: 50%;
+ left: 50%;
+ transform: translate(-50%, -50%);
+ z-index: 1000; /* Higher than the overlay */
+ pointer-events: none; /* Allow clicks to penetrate through to the overlay for resetting */
+}
+
+/* Overlay for resetting zoomed state */
+#civitai_preview_html .zoom-overlay {
+ display: none;
+ position: fixed;
+ top: 0;
+ left: 0;
+ right: 0;
+ bottom: 0;
+ background: rgba(0, 0, 0, .5);
+ z-index: 999; /* Below the zoomed image */
+ cursor: zoom-out;
+}
+
+#civitai_preview_html .zoom-radio:checked + label + .zoom-overlay {
+ display: block;
+ pointer-events: all; /* Capture click events when displayed */
+}
+
+#civitai_preview_html .zoom-img-container {
+ min-width: 20em;
+}
+
+#civitai_preview_html .model-uploader {
+ padding-bottom: 10px;
+ border-bottom: 1px solid; /* no colour given, so the rule uses the current text colour */
+}
+
+#civitai_preview_html .model-description {
+ margin-bottom: 10px;
+ padding-bottom: 10px;
+ border-top: 1px solid; /* no colour given, so the rule uses the current text colour */
+}
+
+/*Avatar CSS mostly copied from CivitAI, but 48px instead of 32px*/
+#civitai_preview_html .avatar {
+ user-select: none;
+ overflow: hidden;
+ width: 48px;
+ height: 48px;
+ min-width: 48px;
+ border-radius: 48px;
+ text-decoration: none;
+ border: 0;
+ padding: 0;
+ background-color: rgba(0,0,0,0.31);
+ display: inline-block!important;
+ margin-left: 5px!important;
+ vertical-align: middle;
+}
+
+#civitai_preview_html .avatar img {
+ object-fit: cover;
+ width: 100%;
+ height: 100%;
+ display: block;
+ overflow-clip-margin: content-box;
+ overflow: clip;
+ border-style: none;
+}
+
+#civitai_preview_html dt {
+ font-size: medium;
+ color: #80a6c8!important;
+}
+
+#civitai_preview_html dd {
+ padding: 0px 0px 10px 10px;
+}
+
+/*CSS accordion for toggling extra metadata*/
+/*-----------------------------------------*/
+#civitai_preview_html .accordionCheckbox {
+ position: absolute;
+ opacity: 0;
+ z-index: -1;
+}
+
+#civitai_preview_html .tabs {
+ border-radius: 10px;
+ overflow: hidden;
+}
+
+#civitai_preview_html .tab {
+ width: 100%;
+ color: white;
+ overflow: hidden;
+ margin-left: -15px;
+}
+
+#civitai_preview_html .tab-label {
+ display: flex;
+ padding: 1em;
+ font-weight: bold;
+ cursor: pointer;
+ font-size: large;
+}
+
+/* Icon */
+#civitai_preview_html .tab-label::before {
+ content: "❯";
+ width: 1em;
+ height: 1em;
+ text-align: center;
+ transition: all 0.3s;
+}
+
+#civitai_preview_html .accordionCheckbox:checked + .tab-label::before {
+ transform: rotate(90deg);
+}
+
+#civitai_preview_html .tab-content {
+ max-height: 0;
+ padding: 0 1em;
+ transition: all 0.3s;
+}
+
+#civitai_preview_html .tab-close {
+ display: flex;
+ justify-content: flex-end;
+ padding: 1em;
+ font-size: 0.75em;
+ cursor: pointer;
+}
+
+#civitai_preview_html .accordionCheckbox:checked ~ .tab-content {
+ max-height: 200vh;
+ padding: 1em;
+}
+/*-----------------------------------------*/
+/*End CSS accordion for toggling extra metadata*/
\ No newline at end of file
diff --git a/sd-civitai-browser-plus/style_html.css b/sd-civitai-browser-plus/style_html.css
new file mode 100644
index 0000000000000000000000000000000000000000..92568e763aeb6a949f65ba98016b32e82f44046d
--- /dev/null
+++ b/sd-civitai-browser-plus/style_html.css
@@ -0,0 +1,244 @@
+body {
+ background-color: #0b0f19;
+}
+
+.model-block {
+ box-shadow: 0px 0px 1px 3px #3339ff30;
+ border-radius: 10px;
+ padding: 1px 20px 10px;
+ margin-bottom: 20px;
+}
+
+.model-block code {
+ white-space: pre-wrap;
+}
+
+.model-block dl {
+ overflow-wrap: anywhere;
+}
+
+.civnsfw img {
+ filter: unset;
+}
+
+.sampleimgs .model-block img,
+.sampleimgs .model-block video {
+ padding-top: 1em;
+ max-width: 20em;
+ cursor: zoom-in;
+ transition: max-width 0.1s;
+}
+
+/* Text adjustments */
+h1, h2, h3, h4, h5, dd, dt, p, a, label {
+ font-family: 'Source Sans Pro', 'ui-sans-serif', 'system-ui', sans-serif;
+}
+
+h3 {
+ font-size: 16px;
+}
+
+h2 {
+ margin: 16px 0px 8px;
+ font-size: 22px;
+}
+
+ul {
+ padding-left: 18px;
+}
+
+p {
+ color: #F3F4F6;
+ margin: 0px 0px 6px;
+ font-size: 14px;
+}
+
+/* Definition-list terms: accent colour with a single explicit 16px font size. */
+dt {
+ color: #80a6c8 !important;
+ font-size: 16px;
+}
+
+dd {
+ padding: 0px 0px 5px 10px;
+ margin-inline-start: 0px;
+ font-size: 14px;
+ color: #F3F4F6;
+}
+
+a {
+ color: #F3F4F6;
+ font-weight: bold;
+ text-decoration: unset;
+}
+
+a:hover {
+ color: #60A5FA;
+}
+
+.civitai_txt2img {
+ display: none;
+}
+
+/* Preview Image zoom */
+.zoom-radio {
+ display: none !important;
+}
+
+/* Style for when the image is clicked (radio button checked) */
+.zoom-radio:checked + label > img,
+.zoom-radio:checked + label > video {
+ max-width: 95vw;
+ max-height: 95vh;
+ padding-top: 0px;
+ cursor: zoom-out;
+ position: fixed;
+ top: 50%;
+ left: 50%;
+ transform: translate(-50%, -50%);
+ z-index: 1000; /* Higher than the overlay */
+ pointer-events: none; /* Allow clicks to penetrate through to the overlay for resetting */
+}
+
+/* Overlay for resetting zoomed state */
+.zoom-overlay {
+ display: none;
+ position: fixed;
+ top: 0;
+ left: 0;
+ right: 0;
+ bottom: 0;
+ background: rgba(0, 0, 0, .5);
+ z-index: 999; /* Below the zoomed image */
+ cursor: zoom-out;
+}
+
+.zoom-radio:checked + label + .zoom-overlay {
+ display: block;
+ pointer-events: all; /* Capture click events when displayed */
+}
+
+.zoom-img-container {
+ min-width: 20em;
+}
+
+.model-uploader {
+ color: white;
+ padding-bottom: 10px;
+ border-bottom: 1px solid; /* border colour follows the white text colour set above */
+}
+
+.model-description {
+ color: white;
+ margin-bottom: 10px;
+ padding-bottom: 10px;
+ border-top: 1px solid; /* border colour follows the white text colour set above */
+}
+
+/*Avatar CSS mostly copied from CivitAI, but 48px instead of 32px*/
+.avatar {
+ user-select: none;
+ overflow: hidden;
+ width: 48px;
+ height: 48px;
+ min-width: 48px;
+ border-radius: 48px;
+ text-decoration: none;
+ border: 0;
+ padding: 0;
+ background-color: rgba(0,0,0,0.31);
+ display: inline-block!important;
+ margin-left: 5px!important;
+ vertical-align: middle;
+}
+
+.avatar img {
+ object-fit: cover;
+ width: 100%;
+ height: 100%;
+ display: block;
+ overflow-clip-margin: content-box;
+ overflow: clip;
+ border-style: none;
+}
+
+/*CSS accordion for toggling extra metadata*/
+/*-----------------------------------------*/
+.accordionCheckbox {
+ position: absolute;
+ opacity: 0;
+ z-index: -1;
+}
+
+.tabs {
+ border-radius: 10px;
+ overflow: hidden;
+}
+
+.tab {
+ width: 100%;
+ color: white;
+ overflow: hidden;
+ margin-left: -15px;
+}
+
+.tab-label {
+ display: flex;
+ padding: 1em;
+ font-weight: bold;
+ cursor: pointer;
+ font-size: large;
+}
+
+.civitai-tags-container {
+ display: flex;
+ flex-wrap: wrap;
+ gap: 5px;
+}
+
+.civitai-tag,
+.civitai-meta-btn {
+ background-color: #111827;
+ border-radius: 8px;
+ padding: 4px 6px;
+ border: 1px solid #374151;
+}
+
+.civitai-meta-btn:hover {
+ cursor: pointer;
+ background-color: #1F2937;
+}
+
+/* Icon */
+.tab-label::before {
+ content: "❯";
+ width: 1em;
+ height: 1em;
+ text-align: center;
+ transition: all 0.3s;
+}
+
+.accordionCheckbox:checked + .tab-label::before {
+ transform: rotate(90deg);
+}
+
+.tab-content {
+ max-height: 0;
+ padding: 0 1em;
+ transition: all 0.3s;
+}
+
+.tab-close {
+ display: flex;
+ justify-content: flex-end;
+ padding: 1em;
+ font-size: 0.75em;
+ cursor: pointer;
+}
+
+.accordionCheckbox:checked ~ .tab-content {
+ max-height: 200vh; /* expanded cap; same value as the #civitai_preview_html accordion rule so long metadata is not cut off sooner here */
+ padding: 1em;
+}
+/*-----------------------------------------*/
+/*End CSS accordion for toggling extra metadata*/
\ No newline at end of file
diff --git a/sd-webui-animatediff/.github/FUNDING.yml b/sd-webui-animatediff/.github/FUNDING.yml
new file mode 100644
index 0000000000000000000000000000000000000000..436b51d30985dd2e231734262ab2e35e2461a86c
--- /dev/null
+++ b/sd-webui-animatediff/.github/FUNDING.yml
@@ -0,0 +1,13 @@
+# These are supported funding model platforms
+
+github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+patreon: conrevo # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: conrevo # Replace with a single Ko-fi username
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
+custom: ['https://paypal.me/conrevo', 'https://afdian.net/a/conrevo'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
diff --git a/sd-webui-animatediff/.github/ISSUE_TEMPLATE/bug_report.yml b/sd-webui-animatediff/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e96f8cd6e416e5c3446955f6a442e49ec5c2a094
--- /dev/null
+++ b/sd-webui-animatediff/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,91 @@
+name: Bug Report
+description: Create a bug report
+title: "[Bug]: "
+labels: ["bug-report"]
+
+body:
+ - type: checkboxes
+ attributes:
+ label: Is there an existing issue for this?
+ description: Please search both open issues and closed issues to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+ options:
+ - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui
+ required: true
+ - type: checkboxes
+ attributes:
+ label: Have you read FAQ on README?
+ description: I have collected some common questions from the original AnimateDiff repository.
+ options:
+ - label: I have updated WebUI and this extension to the latest version
+ required: true
+ - type: markdown
+ attributes:
+ value: |
+ **Please fill this form with as much information as possible, don't forget to fill "What browsers", and provide screenshots if possible.**
+ - type: textarea
+ id: what-did
+ attributes:
+ label: What happened?
+ description: Tell us what happened in a very clear and simple way
+ validations:
+ required: true
+ - type: textarea
+ id: steps
+ attributes:
+ label: Steps to reproduce the problem
+ description: Please provide us with precise step by step information on how to reproduce the bug
+ value: |
+ 1. Go to ....
+ 2. Press ....
+ 3. ...
+ validations:
+ required: true
+ - type: textarea
+ id: what-should
+ attributes:
+ label: What should have happened?
+ description: Tell what you think the normal behavior should be
+ validations:
+ required: true
+ - type: textarea
+ id: commits
+ attributes:
+ label: Commit where the problem happens
+ description: Which commit of the extension are you running on? Please include the commit of both the extension and the webui (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
+ value: |
+ webui:
+ extension:
+ validations:
+ required: true
+ - type: dropdown
+ id: browsers
+ attributes:
+ label: What browsers do you use to access the UI ?
+ multiple: true
+ options:
+ - Mozilla Firefox
+ - Google Chrome
+ - Brave
+ - Apple Safari
+ - Microsoft Edge
+ - type: textarea
+ id: cmdargs
+ attributes:
+ label: Command Line Arguments
+ description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
+ render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: logs
+ attributes:
+ label: Console logs
+ description: Please provide the errors printed on your console log of your browser (type F12 and go to console) and your terminal, after your bug happened.
+ render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: misc
+ attributes:
+ label: Additional information
+ description: Please provide us with any relevant additional info or context.
diff --git a/sd-webui-animatediff/.github/ISSUE_TEMPLATE/config.yml b/sd-webui-animatediff/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0086358db1eb971c0cfa8739c27518bbc18a5ff4
--- /dev/null
+++ b/sd-webui-animatediff/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: true
diff --git a/sd-webui-animatediff/.github/ISSUE_TEMPLATE/feature_request.yml b/sd-webui-animatediff/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 0000000000000000000000000000000000000000..834d9c5ffca3e218ab68ac2621982f6d859b330b
--- /dev/null
+++ b/sd-webui-animatediff/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,13 @@
+name: Feature Request
+description: Create a feature request
+title: "[Feature]: "
+labels: ["feature-request"]
+
+body:
+ - type: textarea
+ id: feature
+ attributes:
+ label: Expected behavior
+ description: Please describe the feature you want.
+ validations:
+ required: true
\ No newline at end of file
diff --git a/sd-webui-animatediff/.gitignore b/sd-webui-animatediff/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..98fba9a2b6dc1da2eb4fa419d6c6b6663009172b
--- /dev/null
+++ b/sd-webui-animatediff/.gitignore
@@ -0,0 +1,4 @@
+__pycache__
+# Ignore every file with an extension located directly under model/
+model/*.*
+TODO.md
\ No newline at end of file
diff --git a/sd-webui-animatediff/README.md b/sd-webui-animatediff/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f0866438fcacfa696faf70975286dc45d0f0a880
--- /dev/null
+++ b/sd-webui-animatediff/README.md
@@ -0,0 +1,353 @@
+# AnimateDiff for Stable Diffusion WebUI
+This extension aims to integrate [AnimateDiff](https://github.com/guoyww/AnimateDiff/) with [CLI](https://github.com/s9roll7/animatediff-cli-prompt-travel) into [AUTOMATIC1111 Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) with [ControlNet](https://github.com/Mikubill/sd-webui-controlnet). You can generate GIFs in exactly the same way as generating images after enabling this extension.
+
+This extension implements AnimateDiff in a different way. It does not require you to clone the whole SD1.5 repository. It also applied (probably) the least modification to `ldm`, so that you do not need to reload your model weights if you don't want to.
+
+You might also be interested in another extension I created: [Segment Anything for Stable Diffusion WebUI](https://github.com/continue-revolution/sd-webui-segment-anything).
+
+[Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) users should either checkout branch [forge/master](https://github.com/continue-revolution/sd-webui-animatediff/tree/forge/master) in this repository or use [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff). They will be in sync.
+
+[TusiArt](https://tusiart.com/) (for users physically inside P.R.China mainland) and [TensorArt](https://tensor.art/) (for users outside P.R.China mainland) offer an online service of this extension.
+
+## Table of Contents
+- [Update](#update)
+- [How to Use](#how-to-use)
+ - [WebUI](#webui)
+ - [API](#api)
+- [WebUI Parameters](#webui-parameters)
+- [Img2GIF](#img2gif)
+- [Prompt Travel](#prompt-travel)
+- [ControlNet V2V](#controlnet-v2v)
+- [Model Spec](#model-spec)
+ - [Motion LoRA](#motion-lora)
+ - [V3](#v3)
+ - [SDXL](#sdxl)
+- [Optimizations](#optimizations)
+ - [Attention](#attention)
+ - [FP8](#fp8)
+ - [LCM](#lcm)
+ - [Others](#others)
+- [Model Zoo](#model-zoo)
+- [VRAM](#vram)
+- [Batch Size](#batch-size)
+- [Demo](#demo)
+ - [Basic Usage](#basic-usage)
+ - [Motion LoRA](#motion-lora-1)
+ - [Prompt Travel](#prompt-travel-1)
+ - [AnimateDiff V3](#animatediff-v3)
+ - [AnimateDiff SDXL](#animatediff-sdxl)
+ - [ControlNet V2V](#controlnet-v2v-1)
+- [Tutorial](#tutorial)
+- [Thanks](#thanks)
+- [Star History](#star-history)
+- [Sponsor](#sponsor)
+
+
+## Update
+- `2023/07/20` [v1.1.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.1.0): Fix gif duration, add loop number, remove auto-download, remove xformers, remove instructions on gradio UI, refactor README, add [sponsor](#sponsor) QR code.
+- `2023/07/24` [v1.2.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.2.0): Fix incorrect insertion of motion modules, add option to change path to motion modules in `Settings/AnimateDiff`, fix loading different motion modules.
+- `2023/09/04` [v1.3.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.3.0): Support any community models with the same architecture; fix grey problem via [#63](https://github.com/continue-revolution/sd-webui-animatediff/issues/63)
+- `2023/09/11` [v1.4.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.4.0): Support official v2 motion module (different architecture: GroupNorm not hacked, UNet middle layer has motion module).
+- `2023/09/14`: [v1.4.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.4.1): Always change `beta`, `alpha_cumprod` and `alpha_cumprod_prev` to resolve grey problem in other samplers.
+- `2023/09/16`: [v1.5.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.5.0): Randomize init latent to support [better img2gif](#img2gif); add other output formats and infotext output; add appending reversed frames; refactor code to ease maintaining.
+- `2023/09/19`: [v1.5.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.5.1): Support xformers, sdp, sub-quadratic attention optimization - [VRAM](#vram) usage decrease to 5.60GB with default setting.
+- `2023/09/22`: [v1.5.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.5.2): Option to disable xformers at `Settings/AnimateDiff` [due to a bug in xformers](https://github.com/facebookresearch/xformers/issues/845), [API support](#api), option to enable GIF palette optimization at `Settings/AnimateDiff`, gifsicle optimization moved to `Settings/AnimateDiff`.
+- `2023/09/25`: [v1.6.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.6.0): [Motion LoRA](https://github.com/guoyww/AnimateDiff#features) supported. See [Motion Lora](#motion-lora) for more information.
+- `2023/09/27`: [v1.7.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.7.0): [ControlNet](https://github.com/Mikubill/sd-webui-controlnet) supported. See [ControlNet V2V](#controlnet-v2v) for more information. [Safetensors](#model-zoo) for some motion modules are also available now.
+- `2023/09/29`: [v1.8.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.8.0): Infinite generation supported. See [WebUI Parameters](#webui-parameters) for more information.
+- `2023/10/01`: [v1.8.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.8.1): Now you can uncheck `Batch cond/uncond` in `Settings/Optimization` if you want. This will reduce your [VRAM](#vram) (5.31GB -> 4.21GB for SDP) but take longer time.
+- `2023/10/08`: [v1.9.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.0): Prompt travel supported. You must have ControlNet installed (you do not need to enable ControlNet) to try it. See [Prompt Travel](#prompt-travel) for how to trigger this feature.
+- `2023/10/11`: [v1.9.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.1): Use state_dict key to guess mm version, replace match case with if else to support python<3.10, option to save PNG to custom dir
+ (see `Settings/AnimateDiff` for detail), move hints to js, install imageio\[ffmpeg\] automatically when MP4 save fails.
+- `2023/10/16`: [v1.9.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.2): Add context generator to completely remove any closed loop, prompt travel support closed loop, infotext fully supported including prompt travel, README refactor
+- `2023/10/19`: [v1.9.3](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.3): Support webp output format. See [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information.
+- `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters.
+- `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information.
+- `2023/10/29`: [v1.11.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.0): [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) supported. See [SDXL](#sdxl) for more information.
+- `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): Optimize VRAM for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) for api return a video, save frames to `AnimateDiff/yy-mm-dd/`, recover from assertion error, optional [request id](#api) for API.
+- `2023/11/10`: [v1.12.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.12.0): [AnimateDiff for SDXL](https://github.com/guoyww/AnimateDiff/tree/sdxl) supported. See [SDXL](#sdxl) for more information.
+- `2023/11/16`: [v1.12.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.12.1): FP8 precision and LCM sampler supported. See [Optimizations](#optimizations) for more information. You can also optionally upload videos to AWS S3 storage by configuring appropriately via `Settings/AnimateDiff AWS`.
+- `2023/12/19`: [v1.13.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.13.0): [AnimateDiff V3](https://github.com/guoyww/AnimateDiff?tab=readme-ov-file#202312-animatediff-v3-and-sparsectrl) supported. See [V3](#v3) for more information. Also: release all official models in fp16 & safetensors format [here](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main), add option to disable LCM sampler in `Settings/AnimateDiff`, remove patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) because A1111 [v1.7.0](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/v1.7.0) now supports video return for API.
+- `2024/01/12`: [v1.13.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.13.1): This small version update completely comes from the community. We fix mp4 encode error [#402](https://github.com/continue-revolution/sd-webui-animatediff/pull/402), support infotext copy-paste [#400](https://github.com/continue-revolution/sd-webui-animatediff/pull/400), validate prompt travel frame numbers [#401](https://github.com/continue-revolution/sd-webui-animatediff/pull/401).
+
+For `v2` update progress, please query [#381](https://github.com/continue-revolution/sd-webui-animatediff/pull/381). For updates (most likely) later than v2, please query [#366](https://github.com/continue-revolution/sd-webui-animatediff/pull/366). All update checklist is tentative and subject to change. `v1.13.x` is the last version update for `v1`. SparseCtrl, Magic Animate and other control methods will be supported in `v2` via updating both this repo and sd-webui-controlnet.
+
+
+## How to Use
+1. Update your WebUI to v1.6.0 and ControlNet to v1.1.410, then install this extension via link. I do not plan to support older versions.
+1. Download motion modules and put the model weights under `stable-diffusion-webui/extensions/sd-webui-animatediff/model/`. If you want to use another directory to save model weights, please go to `Settings/AnimateDiff`. See [model zoo](#model-zoo) for a list of available motion modules.
+1. Enable `Pad prompt/negative prompt to be same length` in `Settings/Optimization` and click `Apply settings`. You must do this to prevent generating two separate unrelated GIFs. Checking `Batch cond/uncond` is optional, which can improve speed but increase VRAM usage.
+1. DO NOT disable hash calculation, otherwise AnimateDiff will have trouble figuring out when you switch motion module.
+
+### WebUI
+1. Go to txt2img if you want to try txt2gif and img2img if you want to try img2gif.
+1. Choose an SD1.5 checkpoint, write prompts, set configurations such as image width/height. If you want to generate multiple GIFs at once, please [change batch number, instead of batch size](#batch-size).
+1. Enable AnimateDiff extension, set up [each parameter](#webui-parameters), then click `Generate`.
+1. You should see the output GIF on the output gallery. You can access GIF output at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/AnimateDiff/{yy-mm-dd}`. You can also access image frames at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/{yy-mm-dd}`. You may choose to save frames for each generation into separate directories in `Settings/AnimateDiff`.
+
+### API
+It is quite similar to the way you use ControlNet. API will return a video in base64 format. In `format`, `PNG` means to save frames to your file system without returning all the frames. If you want your API to return all frames, please add `Frame` to `format` list. For most up-to-date parameters, please read [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_ui.py#L26).
+```
+'alwayson_scripts': {
+ 'AnimateDiff': {
+ 'args': [{
+ 'model': 'mm_sd_v15_v2.ckpt', # Motion module
+ 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'WEBM' | 'TXT' | 'Frame'
+ 'enable': True, # Enable AnimateDiff
+ 'video_length': 16, # Number of frames
+ 'fps': 8, # FPS
+ 'loop_number': 0, # Display loop number
+ 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A'
+ 'batch_size': 16, # Context batch size
+ 'stride': 1, # Stride
+ 'overlap': -1, # Overlap
+ 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM'
+ 'interp_x': 10, # Interp X
+ 'video_source': 'path/to/video.mp4', # Video source
+ 'video_path': 'path/to/frames', # Video path
+ 'latent_power': 1, # Latent power
+ 'latent_scale': 32, # Latent scale
+ 'last_frame': None, # Optional last frame
+ 'latent_power_last': 1, # Optional latent power for last frame
+ 'latent_scale_last': 32,# Optional latent scale for last frame
+ 'request_id': '' # Optional request id. If provided, outputs will have request id as filename suffix
+ }
+ ]
+ }
+},
+```
+
+
+## WebUI Parameters
+1. **Save format** — Format of the output. Choose at least one of "GIF"|"MP4"|"WEBP"|"WEBM"|"PNG". Check "TXT" if you want infotext, which will live in the same directory as the output GIF. Infotext is also accessible via `stable-diffusion-webui/params.txt` and outputs in all formats.
+ 1. You can optimize GIF with `gifsicle` (`apt install gifsicle` required, read [#91](https://github.com/continue-revolution/sd-webui-animatediff/pull/91) for more information) and/or `palette` (read [#104](https://github.com/continue-revolution/sd-webui-animatediff/pull/104) for more information). Go to `Settings/AnimateDiff` to enable them.
+ 1. You can set quality and lossless for WEBP via `Settings/AnimateDiff`. Read [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information.
+ 1. If you are using API, by adding "PNG" to `format`, you can save all frames to your file system without returning all the frames. If you want your API to return all frames, please add `Frame` to `format` list.
+1. **Number of frames** — Choose whatever number you like.
+
+ If you enter 0 (default):
+ - If you submit a video via `Video source` / enter a video path via `Video path` / enable ANY batch ControlNet, the number of frames will be the number of frames in the video (the shortest one is used if more than one video is submitted).
+ - Otherwise, the number of frames will be your `Context batch size` described below.
+
+ If you enter something smaller than your `Context batch size` other than 0: you will get the first `Number of frames` frames as your output GIF from your whole generation. All following frames will not appear in your generated GIF, but will be saved as PNGs as usual. Do not set `Number of frames` to be something smaller than `Context batch size` other than 0 because of [#213](https://github.com/continue-revolution/sd-webui-animatediff/issues/213).
+1. **FPS** — Frames per second, which is how many frames (images) are shown every second. If 16 frames are generated at 8 frames per second, your GIF’s duration is 2 seconds. If you submit a source video, your FPS will be the same as the source video.
+1. **Display loop number** — How many times the GIF is played. A value of `0` means the GIF never stops playing.
+1. **Context batch size** — How many frames will be passed into the motion module at once. The SD1.5 motion modules are trained with 16 frames, so it’ll give the best results when the number of frames is set to `16`. SDXL HotShotXL motion modules are trained with 8 frames instead. Choose [1, 24] for V1 / HotShotXL motion modules and [1, 32] for V2 / AnimateDiffXL motion modules.
+1. **Closed loop** — Closed loop means that this extension will try to make the last frame the same as the first frame.
+ 1. When `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0, closed loop will be performed by AnimateDiff infinite context generator.
+ 1. When `Number of frames` <= `Context batch size`, AnimateDiff infinite context generator will not be effective. Only when you choose `A` will AnimateDiff append reversed list of frames to the original list of frames to form closed loop.
+
+ See below for explanation of each choice:
+
+ - `N` means absolutely no closed loop - this is the only available option if `Number of frames` is smaller than `Context batch size` other than 0.
+ - `R-P` means that the extension will try to reduce the number of closed loop context. The prompt travel will not be interpolated to be a closed loop.
+ - `R+P` means that the extension will try to reduce the number of closed loop context. The prompt travel will be interpolated to be a closed loop.
+ - `A` means that the extension will aggressively try to make the last frame the same as the first frame. The prompt travel will be interpolated to be a closed loop.
+1. **Stride** — Max motion stride as a power of 2 (default: 1).
+ 1. Due to the limitation of the infinite context generator, this parameter is effective only when `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0.
+ 1. "Absolutely no closed loop" is only possible when `Stride` is 1.
+ 1. For each 1 <= $2^i$ <= `Stride`, the infinite context generator will try to make frames $2^i$ apart temporally consistent. For example, if `Stride` is 4 and `Number of frames` is 8, it will make the following frames temporally consistent:
+ - `Stride` == 1: [0, 1, 2, 3, 4, 5, 6, 7]
+ - `Stride` == 2: [0, 2, 4, 6], [1, 3, 5, 7]
+ - `Stride` == 4: [0, 4], [1, 5], [2, 6], [3, 7]
+1. **Overlap** — Number of frames to overlap in context. If overlap is -1 (default): your overlap will be `Context batch size` // 4.
+ 1. Due to the limitation of the infinite context generator, this parameter is effective only when `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0.
+1. **Frame Interpolation** — Interpolate between frames with Deforum's FILM implementation. Requires Deforum extension. [#128](https://github.com/continue-revolution/sd-webui-animatediff/pull/128)
+1. **Interp X** — Replace each input frame with X interpolated output frames. [#128](https://github.com/continue-revolution/sd-webui-animatediff/pull/128).
+1. **Video source** — [Optional] Video source file for [ControlNet V2V](#controlnet-v2v). You MUST enable ControlNet. It will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel. You can of course submit one control image via `Single Image` tab or an input directory via `Batch` tab, which will override this video source input and work as usual.
+1. **Video path** — [Optional] Folder for source frames for [ControlNet V2V](#controlnet-v2v), but lower priority than `Video source`. You MUST enable ControlNet. It will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet. You can of course submit one control image via `Single Image` tab or an input directory via `Batch` tab, which will override this video path input and work as usual.
+ - For people who want to inpaint videos: enter a folder which contains two sub-folders `image` and `mask` on ControlNet inpainting unit. These two sub-folders should contain the same number of images. This extension will match them according to the same sequence. Using my [Segment Anything](https://github.com/continue-revolution/sd-webui-segment-anything) extension can make your life much easier.
+
+Please read
+- [Img2GIF](#img2gif) for extra parameters on img2gif panel.
+- [Prompt Travel](#prompt-travel) for how to trigger prompt travel.
+- [ControlNet V2V](#controlnet-v2v) for how to use ControlNet V2V.
+- [Model Spec](#model-spec) for how to use Motion LoRA, V3 and SDXL.
+
+
+## Img2GIF
+You need to go to img2img and submit an init frame via A1111 panel. You can optionally submit a last frame via extension panel.
+
+By default: your `init_latent` will be changed to
+```
+init_alpha = (1 - frame_number ^ latent_power / latent_scale)
+init_latent = init_latent * init_alpha + random_tensor * (1 - init_alpha)
+```
+
+If you upload a last frame: your `init_latent` will be changed in a similar way. Read [this code](https://github.com/continue-revolution/sd-webui-animatediff/tree/v1.5.0/scripts/animatediff_latent.py#L28-L65) to understand how it works.
+
+
+## Prompt Travel
+Write positive prompt following the example below.
+
+The first line is head prompt, which is optional. You can write no/single/multiple lines of head prompts.
+
+The second and third lines are for prompt interpolation, in format `frame number`: `prompt`. Your `frame number` should be in ascending order, smaller than the total `Number of frames`. The first frame is 0 index.
+
+The last line is tail prompt, which is optional. You can write no/single/multiple lines of tail prompts. If you don't need this feature, just write prompts in the old way.
+```
+1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. UltraDetailed, , ,
+0: closed mouth
+8: open mouth
+smile
+```
+
+
+## ControlNet V2V
+You need to go to txt2img / img2img-batch and submit source video or path to frames. Each ControlNet will find control images according to this priority:
+1. ControlNet `Single Image` tab or `Batch` tab. Simply uploading a control image or a directory of control frames is enough.
+1. Img2img Batch tab `Input directory` if you are using img2img batch. If you upload a directory of control frames, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel.
+1. AnimateDiff `Video Source`. If you upload a video through `Video Source`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel.
+1. AnimateDiff `Video Path`. If you upload a path to frames through `Video Path`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel.
+
+`Number of frames` will be capped to the minimum number of images among all **folders** you provide. Each control image in each folder will be applied to one single frame. If you upload one single image for a ControlNet unit, that image will control **ALL** frames.
+
+For people who want to inpaint videos: enter a folder which contains two sub-folders `image` and `mask` on ControlNet inpainting unit. These two sub-folders should contain the same number of images. This extension will match them according to the same sequence. Using my [Segment Anything](https://github.com/continue-revolution/sd-webui-segment-anything) extension can make your life much easier.
+
+AnimateDiff in img2img batch will be available in [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/pull/224).
+
+
+## Model Spec
+### Motion LoRA
+[Download](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main/lora) and use them like any other LoRA you use (example: download motion lora to `stable-diffusion-webui/models/Lora` and add `<lora:mm_sd15_v2_lora_PanLeft:0.8>` to your positive prompt). **Motion LoRA only supports V2 motion modules**.
+
+### V3
+V3 has identical state dict keys as V1 but slightly different inference logic (GroupNorm is not hacked for V3). You may optionally use [adapter](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) for V3, in the same way as the way you use LoRA. You MUST use [my link](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) instead of the [official link](https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt?download=true). The official adapter won't work for A1111 due to state dict incompatibility.
+
+### SDXL
+[AnimateDiffXL](https://github.com/guoyww/AnimateDiff/tree/sdxl) and [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) have identical architecture to AnimateDiff-SD1.5. The differences are
+- HotShot-XL is trained with 8 frames instead of 16 frames. You are recommended to set `Context batch size` to 8 for HotShot-XL.
+- AnimateDiffXL is still trained with 16 frames. You do not need to change `Context batch size` for AnimateDiffXL.
+- AnimateDiffXL & HotShot-XL have fewer layers compared to AnimateDiff-SD1.5 because of SDXL.
+- AnimateDiffXL is trained with higher resolution compared to HotShot-XL.
+
+Although AnimateDiffXL & HotShot-XL have an identical structure to AnimateDiff-SD1.5, I strongly discourage you from using AnimateDiff-SD1.5 for SDXL, or using HotShot / AnimateDiffXL for SD1.5 - you will get severe artifacts if you do that. I have decided not to support that, despite the fact that it is not hard for me to do that.
+
+Technically all features available for AnimateDiff + SD1.5 are also available for (AnimateDiff / HotShot) + SDXL. However, I have not tested all of them. I have tested infinite context generation and prompt travel; I have not tested ControlNet. If you find any bug, please report it to me.
+
+
+## Optimizations
+Optimizations can be significantly helpful if you want to improve speed and reduce VRAM usage. With [attention optimization](#attention), [FP8](#fp8) and unchecking `Batch cond/uncond` in `Settings/Optimization`, I am able to run 4 x ControlNet + AnimateDiff + Stable Diffusion to generate 36 frames of 1024 * 1024 images with 18GB VRAM.
+
+### Attention
+Adding `--xformers` / `--opt-sdp-attention` to your command lines can significantly reduce VRAM and improve speed. However, due to a bug in xformers, you may or may not get CUDA error. If you get CUDA error, please either completely switch to `--opt-sdp-attention`, or preserve `--xformers` -> go to `Settings/AnimateDiff` -> choose "Optimize attention layers with sdp (torch >= 2.0.0 required)".
+
+### FP8
+FP8 requires torch >= 2.1.0 and WebUI [test-fp8](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/test-fp8) branch by [@KohakuBlueleaf](https://github.com/KohakuBlueleaf). Follow these steps to enable FP8:
+1. Switch to `test-fp8` branch via `git checkout test-fp8` in your `stable-diffusion-webui` directory.
+1. Reinstall torch via adding `--reinstall-torch` ONCE to your command line arguments.
+1. Go to Settings Tab > Optimizations > FP8 weight, change it to `Enable`
+
+### LCM
+[Latent Consistency Model](https://github.com/luosiallen/latent-consistency-model) is a recent breakthrough in the Stable Diffusion community. I provide a "gift" to everyone who updates this extension to >= [v1.12.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.12.1) - you will find the `LCM` sampler in the normal place you select samplers in WebUI. You can generate images / videos within 6-8 steps if you
+- select `Euler A` / `Euler` / `LCM` sampler (other samplers may also work, subject to further experiments)
+- use [LCM LoRA](https://civitai.com/models/195519/lcm-lora-weights-stable-diffusion-acceleration-module)
+- use a low CFG denoising strength (1-2 is recommended)
+
+Note that LCM sampler is still under experiment and subject to change adhering to [@luosiallen](https://github.com/luosiallen)'s wish.
+
+Benefits of using this extension instead of [sd-webui-lcm](https://github.com/0xbitches/sd-webui-lcm) are
+- you do not need to install diffusers
+- you can use LCM sampler with any other extensions, such as ControlNet and AnimateDiff
+
+### Others
+- Remove any VRAM heavy arguments such as `--no-half`. These arguments can significantly increase VRAM usage and reduce speed.
+- Check `Batch cond/uncond` in `Settings/Optimization` to improve speed; uncheck it to reduce VRAM usage.
+
+
+## Model Zoo
+I am maintaining a [huggingface repo](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main) to provide all official models in fp16 & safetensors format. You are highly recommended to use my link. You MUST use my link to download adapter for V3. You may still use the old links if you want, for all models except adapter for V3.
+
+- "Official" models by [@guoyww](https://github.com/guoyww): [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI) | [HuggingFace](https://huggingface.co/guoyww/animatediff/tree/main) | [CivitAI](https://civitai.com/models/108836)
+- "Stabilized" community models by [@manshoety](https://huggingface.co/manshoety): [HuggingFace](https://huggingface.co/manshoety/AD_Stabilized_Motion/tree/main)
+- "TemporalDiff" models by [@CiaraRowles](https://huggingface.co/CiaraRowles): [HuggingFace](https://huggingface.co/CiaraRowles/TemporalDiff/tree/main)
+- "HotShotXL" models by [@hotshotco](https://huggingface.co/hotshotco/): [HuggingFace](https://huggingface.co/hotshotco/Hotshot-XL/tree/main)
+
+
+## VRAM
+Actual VRAM usage depends on your image size and context batch size. You can try to reduce image size or context batch size to reduce VRAM usage.
+
+The following data are SD1.5 + AnimateDiff, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=16 (default setting). `w/`/`w/o` means `Batch cond/uncond` in `Settings/Optimization` is checked/unchecked.
+| Optimization | VRAM w/ | VRAM w/o |
+| --- | --- | --- |
+| No optimization | 12.13GB | |
+| xformers/sdp | 5.60GB | 4.21GB |
+| sub-quadratic | 10.39GB | |
+
+For SDXL + HotShot + SDP, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=8 (default setting), you need 8.66GB VRAM.
+
+For SDXL + AnimateDiff + SDP, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=1024, W=768, frame=16, you need 13.87GB VRAM.
+
+
+## Batch Size
+Batch size on WebUI will be replaced by GIF frame number internally: 1 full GIF generated in 1 batch. If you want to generate multiple GIFs at once, please change batch number.
+
+Batch number is NOT the same as batch size. In A1111 WebUI, batch number is above batch size. Batch number means the number of sequential steps, but batch size means the number of parallel steps. You do not have to worry too much when you increase batch number, but you do need to worry about your VRAM when you increase your batch size (where in this extension, video frame number). You do not need to change batch size at all when you are using this extension.
+
+We are currently developing an approach to support batch size on WebUI in the near future.
+
+
+## Demo
+
+### Basic Usage
+| AnimateDiff | Extension | img2img |
+| --- | --- | --- |
+| ![image](https://user-images.githubusercontent.com/63914308/255306527-5105afe8-d497-4ab1-b5c4-37540e9601f8.gif) |![00013-10788741199826055000](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/43b9cf34-dbd1-4120-b220-ea8cb7882272) | ![00018-727621716](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/d04bb573-c8ca-4ae6-a2d9-81f8012bec3a) |
+
+### Motion LoRA
+| No LoRA | PanDown | PanLeft |
+| --- | --- | --- |
+| ![00094-1401397431](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/d8d2b860-c781-4dd0-8c0a-0eb26970130b) | ![00095-3197605735](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/aed2243f-5494-4fe3-a10a-96c57f6f2906) | ![00093-2722547708](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/c32e9aaf-54f2-4f40-879b-e800c7c7848c) |
+
+### Prompt Travel
+![00201-2296305953](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/881f317c-f1d2-4635-b84b-b4c4881650f6)
+
+The prompt is similar to [above](#prompt-travel).
+
+### AnimateDiff V3
+You should be able to read infotext to understand how I generated this sample.
+![00024-3973810345](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/5f3e3858-8033-4a16-94b0-4dbc0d0a67fc)
+
+
+### AnimateDiff SDXL
+You should be able to read infotext to understand how I generated this sample.
+![00025-1668075705](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/6d32daf9-51c6-490f-a942-db36f84f23cf)
+
+### ControlNet V2V
+TODO
+
+
+## Tutorial
+TODO
+
+
+## Thanks
+I thank researchers from [Shanghai AI Lab](https://www.shlab.org.cn/), especially [@guoyww](https://github.com/guoyww) for creating AnimateDiff. I also thank [@neggles](https://github.com/neggles) and [@s9roll7](https://github.com/s9roll7) for creating and improving [AnimateDiff CLI Prompt Travel](https://github.com/s9roll7/animatediff-cli-prompt-travel). This extension could not be made possible without these creative works.
+
+I also thank community developers, especially
+- [@zappityzap](https://github.com/zappityzap) who developed the majority of the [output features](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_output.py)
+- [@TDS4874](https://github.com/TDS4874) and [@opparco](https://github.com/opparco) for resolving the grey issue, which significantly improves the performance
+- [@talesofai](https://github.com/talesofai) who developed i2v in [this forked repo](https://github.com/talesofai/AnimateDiff)
+- [@rkfg](https://github.com/rkfg) for developing GIF palette optimization
+
+and many others who have contributed to this extension.
+
+I also thank community users, especially [@streamline](https://twitter.com/kaizirod) who provided dataset and workflow of ControlNet V2V. His workflow is extremely amazing and definitely worth checking out.
+
+
+## Star History
+
+
+
+
+
+## Sponsor
+You can sponsor me via WeChat, AliPay or [PayPal](https://paypal.me/conrevo). You can also support me via [patreon](https://www.patreon.com/conrevo), [ko-fi](https://ko-fi.com/conrevo) or [afdian](https://afdian.net/a/conrevo).
+
+| WeChat | AliPay | PayPal |
+| --- | --- | --- |
+| ![216aff0250c7fd2bb32eeb4f7aae623](https://user-images.githubusercontent.com/63914308/232824466-21051be9-76ce-4862-bb0d-a431c186fce1.jpg) | ![15fe95b4ada738acf3e44c1d45a1805](https://user-images.githubusercontent.com/63914308/232824545-fb108600-729d-4204-8bec-4fd5cc8a14ec.jpg) | ![IMG_1419_](https://github.com/continue-revolution/sd-webui-animatediff/assets/63914308/eaa7b114-a2e6-4ecc-a29f-253ace06d1ea) |
diff --git a/sd-webui-animatediff/model/.gitkeep b/sd-webui-animatediff/model/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-webui-animatediff/motion_module.py b/sd-webui-animatediff/motion_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..60dcbe99a67c8b4e509e6ab2c0aaf318cbdeccf8
--- /dev/null
+++ b/sd-webui-animatediff/motion_module.py
@@ -0,0 +1,657 @@
+from enum import Enum
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from modules import sd_hijack, shared
+from ldm.modules.attention import FeedForward
+
+from einops import rearrange, repeat
+import math
+
+
+class MotionModuleType(Enum):
+    """Motion-module families that can be told apart from their checkpoint keys."""
+
+    # The values are human-readable labels and are kept verbatim.
+    AnimateDiffV1 = "AnimateDiff V1, Yuwei GUo, Shanghai AI Lab"
+    AnimateDiffV2 = "AnimateDiff V2, Yuwei Guo, Shanghai AI Lab"
+    AnimateDiffV3 = "AnimateDiff V3, Yuwei Guo, Shanghai AI Lab"
+    AnimateDiffXL = "AnimateDiff SDXL, Yuwei Guo, Shanghai AI Lab"
+    HotShotXL = "HotShot-XL, John Mullan, Natural Synthetics Inc"
+
+
+    @staticmethod
+    def get_mm_type(state_dict: dict[str, torch.Tensor]):
+        """Guess the motion-module family from the key names of ``state_dict``.
+
+        Checks are applied in this order:
+          * any key containing ``mid_block``           -> AnimateDiffV2
+          * any key containing ``temporal_attentions`` -> HotShotXL
+          * any key containing ``down_blocks.3``       -> AnimateDiffV3 when the first
+            tensor whose key contains ``pe`` has a dimension of size 32, else AnimateDiffV1
+          * otherwise                                  -> AnimateDiffXL
+
+        Args:
+            state_dict: mapping from parameter name to tensor.
+
+        Returns:
+            The matching ``MotionModuleType`` member.
+        """
+        keys = list(state_dict.keys())
+        if any("mid_block" in k for k in keys):
+            return MotionModuleType.AnimateDiffV2
+        if any("temporal_attentions" in k for k in keys):
+            return MotionModuleType.HotShotXL
+        if any("down_blocks.3" in k for k in keys):
+            # A checkpoint without any 'pe' tensor used to crash on `None.shape`;
+            # treat it as V1, the same answer given when no dimension equals 32.
+            pe = next((state_dict[k] for k in keys if 'pe' in k), None)
+            if pe is not None and 32 in pe.shape:
+                return MotionModuleType.AnimateDiffV3
+            return MotionModuleType.AnimateDiffV1
+        return MotionModuleType.AnimateDiffXL
+
+
+def zero_module(module):
+    """Overwrite every parameter of ``module`` with zeros in place and return the same module."""
+    parameters = module.parameters()
+    for weight in parameters:
+        # detach() shares storage, so zeroing the detached view zeroes the parameter itself.
+        weight.detach().zero_()
+    return module
+
+
+class MotionWrapper(nn.Module):
+    """Holds the down/up (and, for V2, mid) motion modules of one motion-module checkpoint."""
+
+    def __init__(self, mm_name: str, mm_hash: str, mm_type: MotionModuleType):
+        """Build an empty module tree whose layout depends on ``mm_type``.
+
+        Args:
+            mm_name: name of the motion module, stored as ``self.mm_name``.
+            mm_hash: hash of the motion module, stored as ``self.mm_hash``.
+            mm_type: family of the checkpoint; decides channel layout, maximum
+                positional-encoding length and whether a mid block exists.
+        """
+        super().__init__()
+        # Family flags.
+        self.is_v2 = mm_type == MotionModuleType.AnimateDiffV2
+        self.is_v3 = mm_type == MotionModuleType.AnimateDiffV3
+        self.is_hotshot = mm_type == MotionModuleType.HotShotXL
+        self.is_adxl = mm_type == MotionModuleType.AnimateDiffXL
+        self.is_xl = self.is_hotshot or self.is_adxl
+
+        uses_long_encoding = self.is_v2 or self.is_adxl or self.is_v3
+        max_len = 32 if uses_long_encoding else 24
+        if self.is_xl:
+            in_channels = (320, 640, 1280)
+        else:
+            in_channels = (320, 640, 1280, 1280)
+
+        self.down_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+        for channels in in_channels:
+            down = MotionModule(channels, num_mm=2, max_len=max_len, is_hotshot=self.is_hotshot)
+            self.down_blocks.append(down)
+            # Up blocks are prepended, so they end up in reverse channel order.
+            up = MotionModule(channels, num_mm=3, max_len=max_len, is_hotshot=self.is_hotshot)
+            self.up_blocks.insert(0, up)
+        if self.is_v2:
+            self.mid_block = MotionModule(1280, num_mm=1, max_len=max_len)
+
+        self.mm_name = mm_name
+        self.mm_type = mm_type
+        self.mm_hash = mm_hash
+
+
+    def enable_gn_hack(self):
+        """Return True unless this is an AnimateDiff SDXL or AnimateDiff V3 module."""
+        return (not self.is_adxl) and (not self.is_v3)
+
+
+class MotionModule(nn.Module):
+    """A list of ``num_mm`` temporal modules for one channel width."""
+
+    def __init__(self, in_channels, num_mm, max_len, is_hotshot=False):
+        """Create ``num_mm`` modules via ``get_motion_module``.
+
+        The list is stored as ``temporal_attentions`` for HotShot-XL and as
+        ``motion_modules`` otherwise.
+        """
+        super().__init__()
+        stack = nn.ModuleList()
+        for _ in range(num_mm):
+            stack.append(get_motion_module(in_channels, max_len, is_hotshot))
+        attribute = "temporal_attentions" if is_hotshot else "motion_modules"
+        setattr(self, attribute, stack)
+
+
+
+def get_motion_module(in_channels, max_len, is_hotshot):
+    """Build one temporal module.
+
+    Returns the bare ``temporal_transformer`` for HotShot-XL and the wrapping
+    ``VanillaTemporalModule`` otherwise.
+    """
+    module = VanillaTemporalModule(
+        in_channels=in_channels,
+        temporal_position_encoding_max_len=max_len,
+        is_hotshot=is_hotshot,
+    )
+    if is_hotshot:
+        return module.temporal_transformer
+    return module
+
+
+class VanillaTemporalModule(nn.Module):
+    """Thin wrapper around a ``TemporalTransformer3DModel`` stored as ``temporal_transformer``."""
+
+    def __init__(
+        self,
+        in_channels,
+        num_attention_heads                = 8,
+        num_transformer_block              = 1,
+        attention_block_types              =( "Temporal_Self", "Temporal_Self" ),
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = True,
+        temporal_position_encoding_max_len = 24,
+        temporal_attention_dim_div         = 1,
+        zero_initialize                    = True,
+        is_hotshot                         = False,
+    ):
+        """Create the transformer; with ``zero_initialize`` its output projection starts at zero."""
+        super().__init__()
+
+        head_dim = in_channels // num_attention_heads // temporal_attention_dim_div
+        transformer = TemporalTransformer3DModel(
+            in_channels=in_channels,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=head_dim,
+            num_layers=num_transformer_block,
+            attention_block_types=attention_block_types,
+            cross_frame_attention_mode=cross_frame_attention_mode,
+            temporal_position_encoding=temporal_position_encoding,
+            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+            is_hotshot=is_hotshot,
+        )
+        self.temporal_transformer = transformer
+
+        if zero_initialize:
+            transformer.proj_out = zero_module(transformer.proj_out)
+
+
+    def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None): # TODO: encoder_hidden_states do seem to be always None
+        """Delegate to the wrapped temporal transformer."""
+        return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
+
+
+class TemporalTransformer3DModel(nn.Module):
+    """GroupNorm -> linear in -> temporal transformer blocks -> linear out, with a residual connection."""
+
+    def __init__(
+        self,
+        in_channels,
+        num_attention_heads,
+        attention_head_dim,
+
+        num_layers,
+        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),
+        dropout                            = 0.0,
+        norm_num_groups                    = 32,
+        cross_attention_dim                = 768,
+        activation_fn                      = "geglu",
+        attention_bias                     = False,
+        upcast_attention                   = False,
+
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
+        is_hotshot                         = False,
+    ):
+        """Create the norm, the two projections and ``num_layers`` transformer blocks."""
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        # Every block is built from the same settings.
+        block_settings = dict(
+            dim=inner_dim,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            attention_block_types=attention_block_types,
+            dropout=dropout,
+            norm_num_groups=norm_num_groups,
+            cross_attention_dim=cross_attention_dim,
+            activation_fn=activation_fn,
+            attention_bias=attention_bias,
+            upcast_attention=upcast_attention,
+            cross_frame_attention_mode=cross_frame_attention_mode,
+            temporal_position_encoding=temporal_position_encoding,
+            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+            is_hotshot=is_hotshot,
+        )
+        self.transformer_blocks = nn.ModuleList(
+            [TemporalTransformerBlock(**block_settings) for _ in range(num_layers)]
+        )
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        """Apply the temporal transformer to a 4-D input and add the input back.
+
+        The leading dimension is divided by 2 to obtain the video length when
+        ``shared.opts.batch_cond_uncond`` is set. ``attention_mask`` is accepted
+        but not forwarded to the blocks.
+        """
+        groups = 2 if shared.opts.batch_cond_uncond else 1
+        video_length = hidden_states.shape[0] // groups
+        batch, channel, height, width = hidden_states.shape
+        residual = hidden_states
+
+        normed = self.norm(hidden_states).type(hidden_states.dtype)
+        inner_dim = normed.shape[1]
+        # Channels last, then flatten the two spatial axes into one token axis.
+        tokens = normed.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+        tokens = self.proj_in(tokens)
+
+        # Transformer Blocks
+        for block in self.transformer_blocks:
+            tokens = block(tokens, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
+
+        # output
+        tokens = self.proj_out(tokens)
+        restored = tokens.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        return restored + residual
+
+
+class TemporalTransformerBlock(nn.Module):
+    """A chain of pre-normed temporal attention layers followed by a pre-normed feed-forward layer."""
+
+    def __init__(
+        self,
+        dim,
+        num_attention_heads,
+        attention_head_dim,
+        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),
+        dropout                            = 0.0,
+        norm_num_groups                    = 32,
+        cross_attention_dim                = 768,
+        activation_fn                      = "geglu",
+        attention_bias                     = False,
+        upcast_attention                   = False,
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
+        is_hotshot                         = False,
+    ):
+        """Create one attention layer plus LayerNorm per entry of ``attention_block_types``."""
+        super().__init__()
+
+        self.attention_blocks = nn.ModuleList()
+        self.norms = nn.ModuleList()
+
+        for block_name in attention_block_types:
+            # "<Mode>_<Kind>": the mode precedes the first underscore; only "_Cross" kinds get a cross dim.
+            mode = block_name.partition("_")[0]
+            is_cross = block_name.endswith("_Cross")
+            attention = VersatileAttention(
+                attention_mode=mode,
+                cross_attention_dim=cross_attention_dim if is_cross else None,
+
+                query_dim=dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+
+                cross_frame_attention_mode=cross_frame_attention_mode,
+                temporal_position_encoding=temporal_position_encoding,
+                temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+                is_hotshot=is_hotshot,
+            )
+            self.attention_blocks.append(attention)
+            self.norms.append(nn.LayerNorm(dim))
+
+        self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn=='geglu'))
+        self.ff_norm = nn.LayerNorm(dim)
+
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+        """Run each attention layer and then the feed-forward layer, each with a residual add.
+
+        ``encoder_hidden_states`` is passed only to layers flagged as cross
+        attention; ``attention_mask`` is accepted but unused.
+        """
+        for attention_block, norm in zip(self.attention_blocks, self.norms):
+            normed = norm(hidden_states).type(hidden_states.dtype)
+            context = encoder_hidden_states if attention_block.is_cross_attention else None
+            attended = attention_block(
+                normed,
+                encoder_hidden_states=context,
+                video_length=video_length,
+            )
+            hidden_states = attended + hidden_states
+
+        ff_input = self.ff_norm(hidden_states).type(hidden_states.dtype)
+        hidden_states = self.ff(ff_input) + hidden_states
+
+        return hidden_states
+
+
+class PositionalEncoding(nn.Module):
+    """Adds a fixed sinusoidal position table to the second axis of the input."""
+
+    def __init__(
+        self,
+        d_model,
+        dropout = 0.,
+        max_len = 24,
+        is_hotshot = False,
+    ):
+        """Precompute a ``(1, max_len, d_model)`` table and register it as a buffer.
+
+        The buffer is named ``positional_encoding`` for HotShot-XL and ``pe``
+        otherwise.
+        """
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+        angles = position * div_term
+
+        table = torch.zeros(1, max_len, d_model)
+        table[0, :, 0::2] = torch.sin(angles)  # even feature indices
+        table[0, :, 1::2] = torch.cos(angles)  # odd feature indices
+
+        buffer_name = 'positional_encoding' if is_hotshot else 'pe'
+        self.register_buffer(buffer_name, table)
+        self.is_hotshot = is_hotshot
+
+    def forward(self, x):
+        """Return ``dropout(x + table[:, :x.size(1)])``."""
+        if self.is_hotshot:
+            table = self.positional_encoding
+        else:
+            table = self.pe
+        x = x + table[:, :x.size(1)]
+        return self.dropout(x)
+
+
+class CrossAttention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`): The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias=False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+    ):
+        super().__init__()
+        inner_dim = dim_head * heads
+        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+
+        self.scale = dim_head**-0.5
+
+        self.heads = heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+        self._slice_size = None
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+        else:
+            self.group_norm = None
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))
+
+    def reshape_heads_to_batch_dim(self, tensor):
+        """Split heads out of the last axis: (batch, seq, dim) -> (batch * heads, seq, dim // heads)."""
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        """Inverse of ``reshape_heads_to_batch_dim``: (batch * heads, seq, dim) -> (batch, seq, dim * heads)."""
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def set_attention_slice(self, slice_size):
+        """Set the slice size for sliced attention (``None`` disables slicing).
+
+        Raises:
+            ValueError: if ``slice_size`` exceeds the number of heads.
+        """
+        if slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        self._slice_size = slice_size
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        """Attend from ``hidden_states`` to ``encoder_hidden_states`` (or to itself when that is None).
+
+        Returns the attention output after the output projection and dropout.
+        """
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        if self.group_norm is not None:
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2).type(hidden_states.dtype)
+
+        query = self.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = self.reshape_heads_to_batch_dim(query)
+
+        if self.added_kv_proj_dim is not None:
+            key = self.to_k(hidden_states)
+            value = self.to_v(hidden_states)
+            encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+            key = self.reshape_heads_to_batch_dim(key)
+            value = self.reshape_heads_to_batch_dim(value)
+            encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+            encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+            key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+            key = self.to_k(encoder_hidden_states)
+            value = self.to_v(encoder_hidden_states)
+
+            key = self.reshape_heads_to_batch_dim(key)
+            value = self.reshape_heads_to_batch_dim(value)
+
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+        # attention, what we cannot get enough of
+        if sd_hijack.current_optimizer is not None and sd_hijack.current_optimizer.name in ["xformers", "sdp", "sdp-no-mem", "sub-quadratic"]:
+            hidden_states = self._memory_efficient_attention(query, key, value, attention_mask, sd_hijack.current_optimizer.name)
+            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            hidden_states = hidden_states.to(query.dtype)
+        else:
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value, attention_mask)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states
+
+    def _attention(self, query, key, value, attention_mask=None):
+        """Plain softmax attention on head-split tensors; returns the result with heads merged back."""
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        # beta=0 means the uninitialised `empty` tensor only supplies shape/dtype/device.
+        attention_scores = torch.baddbmm(
+            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            query,
+            key.transpose(-1, -2),
+            beta=0,
+            alpha=self.scale,
+        )
+
+        if attention_mask is not None:
+            attention_scores = attention_scores + attention_mask
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+
+        # cast back to the original dtype
+        attention_probs = attention_probs.to(value.dtype)
+
+        # compute attention output
+        hidden_states = torch.bmm(attention_probs, value)
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+        """Same result as ``_attention`` but computed ``_slice_size`` batch entries at a time.
+
+        Args:
+            sequence_length: length of the query sequence (second axis of ``query``).
+            dim: total inner dimension, i.e. per-head dimension times ``self.heads``.
+        """
+        batch_size_attention = query.shape[0]
+        hidden_states = torch.zeros(
+            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+        )
+        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+        for i in range(hidden_states.shape[0] // slice_size):
+            start_idx = i * slice_size
+            end_idx = (i + 1) * slice_size
+
+            query_slice = query[start_idx:end_idx]
+            key_slice = key[start_idx:end_idx]
+
+            if self.upcast_attention:
+                query_slice = query_slice.float()
+                key_slice = key_slice.float()
+
+            attn_slice = torch.baddbmm(
+                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+                query_slice,
+                key_slice.transpose(-1, -2),
+                beta=0,
+                alpha=self.scale,
+            )
+
+            if attention_mask is not None:
+                attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+            if self.upcast_softmax:
+                attn_slice = attn_slice.float()
+
+            attn_slice = attn_slice.softmax(dim=-1)
+
+            # cast back to the original dtype
+            attn_slice = attn_slice.to(value.dtype)
+            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+            hidden_states[start_idx:end_idx] = attn_slice
+
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+    def _memory_efficient_attention(self, q, k, v, mask, current_optimizer_name):
+        """Run attention with the named optimized kernel, cascading to the next one on failure.
+
+        Order: xformers -> sdp -> sdp-no-mem -> sub-quadratic -> plain attention.
+        The cascade starts at ``current_optimizer_name``; each stage that raises
+        ImportError, RuntimeError or AttributeError hands over to the next.
+
+        Args:
+            q, k, v: head-split tensors of shape (batch * heads, seq, head_dim).
+            mask: optional attention mask / bias, or None.
+            current_optimizer_name: one of "xformers", "sdp", "sdp-no-mem", "sub-quadratic".
+
+        Returns:
+            Attention output with heads merged back into the last axis.
+        """
+        # TODO attention_mask
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+
+        hidden_states = None
+        fallthrough = False
+
+        if current_optimizer_name == "xformers" or fallthrough:
+            fallthrough = False
+            try:
+                import xformers.ops
+                from modules.sd_hijack_optimizations import get_xformers_flash_attention_op
+                hidden_states = xformers.ops.memory_efficient_attention(
+                    q, k, v, attn_bias=mask,
+                    op=get_xformers_flash_attention_op(q, k, v))
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if current_optimizer_name == "sdp" or fallthrough:
+            fallthrough = False
+            try:
+                hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+                )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if current_optimizer_name == "sdp-no-mem" or fallthrough:
+            fallthrough = False
+            try:
+                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+                    hidden_states = torch.nn.functional.scaled_dot_product_attention(
+                        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+                    )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if current_optimizer_name == "sub-quadratic" or fallthrough:
+            fallthrough = False
+            try:
+                from modules.sd_hijack_optimizations import sub_quad_attention
+                from modules import shared
+                hidden_states = sub_quad_attention(
+                    q, k, v,
+                    q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
+                    kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
+                    chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
+                    use_checkpoint=self.training
+                )
+            except (ImportError, RuntimeError, AttributeError):
+                fallthrough = True
+
+        if fallthrough or hidden_states is None:
+            # Last resort: the unoptimized implementations. They must be fed the
+            # local q/k/v/mask (the names used by `forward` do not exist here) and
+            # they already merge the heads, so return their result directly.
+            if self._slice_size is None or q.shape[0] // self._slice_size == 1:
+                return self._attention(q, k, v, mask)
+            # q is head-split: its second axis is the sequence length and its last
+            # axis is the per-head dim, so the full inner dim is last-axis * heads.
+            return self._sliced_attention(q, k, v, q.shape[1], q.shape[-1] * self.heads, mask)
+
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+
+class VersatileAttention(CrossAttention):
+    """Attention applied along the frame axis: each spatial position attends across video frames."""
+
+    def __init__(
+            self,
+            attention_mode                     = None,
+            cross_frame_attention_mode         = None,
+            temporal_position_encoding         = False,
+            temporal_position_encoding_max_len = 24,
+            is_hotshot                         = False,
+            *args, **kwargs
+        ):
+        """Set up the parent attention layer and, optionally, a positional encoder.
+
+        Only ``attention_mode == "Temporal"`` is supported (asserted). ``query_dim``
+        must be passed by keyword because it is read from ``kwargs``.
+        """
+        super().__init__(*args, **kwargs)
+        assert attention_mode == "Temporal"
+
+        self.attention_mode = attention_mode
+        # The parent treats cross_attention_dim as optional, so a missing keyword
+        # means "self attention" here rather than a KeyError.
+        self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
+
+        self.pos_encoder = PositionalEncoding(
+            kwargs["query_dim"],
+            dropout=0.,
+            max_len=temporal_position_encoding_max_len,
+            is_hotshot=is_hotshot,
+        ) if (temporal_position_encoding and attention_mode == "Temporal") else None
+
+    def extra_repr(self):
+        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
+
+    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+        """Attend across frames.
+
+        The input of shape ``(b * f, d, c)`` with ``f == video_length`` is regrouped
+        to ``(b * d, f, c)``, attended, projected, and regrouped back to
+        ``(b * f, d, c)``.
+
+        Raises:
+            NotImplementedError: for a non-temporal mode or when ``added_kv_proj_dim`` is set.
+        """
+        if self.attention_mode == "Temporal":
+            d = hidden_states.shape[1]
+            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+            if self.pos_encoder is not None:
+                hidden_states = self.pos_encoder(hidden_states)
+
+            encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
+        else:
+            raise NotImplementedError
+
+        if self.group_norm is not None:
+            # `.type(...)` casts; `.dtype` is an attribute and cannot be called.
+            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2).type(hidden_states.dtype)
+
+        query = self.to_q(hidden_states)
+        dim = query.shape[-1]
+        query = self.reshape_heads_to_batch_dim(query)
+        # Measured after the rearrange: the sliced path sizes its output buffer
+        # with this value, which has to be the frame axis and not the spatial one.
+        sequence_length = query.shape[1]
+
+        if self.added_kv_proj_dim is not None:
+            raise NotImplementedError
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+        key = self.to_k(encoder_hidden_states)
+        value = self.to_v(encoder_hidden_states)
+
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
+
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+        xformers_option = shared.opts.data.get("animatediff_xformers", "Optimize attention layers with xformers")
+        optimizer_collections = ["xformers", "sdp", "sdp-no-mem", "sub-quadratic"]
+        if xformers_option == "Do not optimize attention layers": # "Do not optimize attention layers"
+            optimizer_collections = optimizer_collections[1:]
+
+        # attention, what we cannot get enough of
+        if sd_hijack.current_optimizer is not None and sd_hijack.current_optimizer.name in optimizer_collections:
+            optimizer_name = sd_hijack.current_optimizer.name
+            if xformers_option == "Optimize attention layers with sdp (torch >= 2.0.0 required)" and optimizer_name == "xformers":
+                optimizer_name = "sdp" # "Optimize attention layers with sdp (torch >= 2.0.0 required)"
+            hidden_states = self._memory_efficient_attention(query, key, value, attention_mask, optimizer_name)
+            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+            hidden_states = hidden_states.to(query.dtype)
+        else:
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value, attention_mask)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+
+        if self.attention_mode == "Temporal":
+            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+        return hidden_states
diff --git a/sd-webui-animatediff/scripts/animatediff.py b/sd-webui-animatediff/scripts/animatediff.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1a19aee0f66c2aa323865c278ce7f30b2f4d5a4
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff.py
@@ -0,0 +1,305 @@
+import gradio as gr
+from modules import script_callbacks, scripts, shared
+from modules.processing import (Processed, StableDiffusionProcessing,
+ StableDiffusionProcessingImg2Img)
+from modules.scripts import PostprocessBatchListArgs, PostprocessImageArgs
+
+from scripts.animatediff_cn import AnimateDiffControl
+from scripts.animatediff_infv2v import AnimateDiffInfV2V
+from scripts.animatediff_latent import AnimateDiffI2VLatent
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_lora import AnimateDiffLora
+from scripts.animatediff_mm import mm_animatediff as motion_module
+from scripts.animatediff_prompt import AnimateDiffPromptSchedule
+from scripts.animatediff_output import AnimateDiffOutput
+from scripts.animatediff_ui import AnimateDiffProcess, AnimateDiffUiGroup, supported_save_formats
+from scripts.animatediff_infotext import update_infotext, infotext_pasted
+from scripts.animatediff_xyz import patch_xyz, xyz_attrs
+
+# Capture the directory reported by scripts.basedir() at import time (presumably
+# this extension's root — confirm) and pass it to the motion-module manager.
+script_dir = scripts.basedir()
+motion_module.set_script_dir(script_dir)
+
+
+class AnimateDiffScript(scripts.Script):
+    """Always-visible script that installs the AnimateDiff hooks before generation and removes them afterwards."""
+
+    def __init__(self):
+        self.lora_hacker = None
+        self.cfg_hacker = None
+        self.cn_hacker = None
+        self.prompt_scheduler = None
+        self.hacked = False
+        self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
+        self.paste_field_names: List[str] = []
+
+
+    def title(self):
+        return "AnimateDiff"
+
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+
+    def ui(self, is_img2img):
+        """Render the AnimateDiff UI group and return it as a one-element tuple."""
+        group = AnimateDiffUiGroup()
+        unit = group.render(
+            is_img2img,
+            motion_module.get_model_dir(),
+            self.infotext_fields,
+            self.paste_field_names
+        )
+        return (unit,)
+
+
+    def _api_params(self, p, params):
+        """For API calls that passed a dict, return the params object built in ``before_process``."""
+        if p.is_api and isinstance(params, dict):
+            return self.ad_params
+        return params
+
+
+    def _undo_hacks(self, p):
+        """Restore ControlNet, CFG, LoRA and the motion module (in that order) and clear ``hacked``."""
+        self.cn_hacker.restore()
+        self.cfg_hacker.restore()
+        self.lora_hacker.restore()
+        motion_module.restore(p.sd_model)
+        self.hacked = False
+
+
+    def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
+        """Install all hooks when AnimateDiff is enabled; undo leftover hooks when it is not."""
+        if p.is_api and isinstance(params, dict):
+            self.ad_params = AnimateDiffProcess(**params)
+            params = self.ad_params
+
+        # apply XYZ settings
+        params.apply_xyz()
+        xyz_attrs.clear()
+
+        if not params.enable:
+            if self.hacked:
+                self._undo_hacks(p)
+            return
+
+        logger.info("AnimateDiff process start.")
+        params.set_p(p)
+        motion_module.inject(p.sd_model, params.model)
+        self.prompt_scheduler = AnimateDiffPromptSchedule()
+        self.lora_hacker = AnimateDiffLora(motion_module.mm.is_v2)
+        self.lora_hacker.hack()
+        self.cfg_hacker = AnimateDiffInfV2V(p, self.prompt_scheduler)
+        self.cfg_hacker.hack(params)
+        self.cn_hacker = AnimateDiffControl(p, self.prompt_scheduler)
+        self.cn_hacker.hack(params)
+        update_infotext(p, params)
+        self.hacked = True
+
+
+    def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiffProcess, **kwargs):
+        """Randomize the img2img latent unless the run is flagged as an AnimateDiff i2i batch."""
+        params = self._api_params(p, params)
+        if not params.enable:
+            return
+        is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
+        if is_img2img and not hasattr(p, '_animatediff_i2i_batch'):
+            AnimateDiffI2VLatent().randomize(p, params)
+
+
+    def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, params: AnimateDiffProcess, **kwargs):
+        params = self._api_params(p, params)
+        if params.enable:
+            self.prompt_scheduler.save_infotext_img(p)
+
+
+    def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, params: AnimateDiffProcess, *args):
+        """Point ``p.paste_to`` at the stored paste target for the current batch index."""
+        params = self._api_params(p, params)
+        if not params.enable:
+            return
+        is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
+        if is_img2img and hasattr(p, '_animatediff_paste_to_full'):
+            p.paste_to = p._animatediff_paste_to_full[p.batch_index]
+
+
+    def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
+        """Save the infotext, remove all hooks and write the output."""
+        params = self._api_params(p, params)
+        if not params.enable:
+            return
+        self.prompt_scheduler.save_infotext_txt(res)
+        self._undo_hacks(p)
+        AnimateDiffOutput().output(p, res, params)
+        logger.info("AnimateDiff process end.")
+
+
+def on_ui_settings():
+    """Register every AnimateDiff option with ``shared.opts``.
+
+    General options go to the ("animatediff", "AnimateDiff") section and the
+    S3 options to ("animatediff", "AnimateDiff AWS").
+    """
+    section = ("animatediff", "AnimateDiff")
+    s3_selection =("animatediff", "AnimateDiff AWS")
+    register = shared.opts.add_option
+    Info = shared.OptionInfo
+
+    register(
+        "animatediff_model_path",
+        Info(None, "Path to save AnimateDiff motion modules", gr.Textbox, section=section),
+    )
+    register(
+        "animatediff_default_save_formats",
+        Info(
+            ["GIF", "PNG"],
+            "Default Save Formats",
+            gr.CheckboxGroup,
+            {"choices": supported_save_formats},
+            section=section,
+        ).needs_restart(),
+    )
+    register(
+        "animatediff_optimize_gif_palette",
+        Info(
+            False,
+            "Calculate the optimal GIF palette, improves quality significantly, removes banding",
+            gr.Checkbox,
+            section=section,
+        ),
+    )
+    register(
+        "animatediff_optimize_gif_gifsicle",
+        Info(False, "Optimize GIFs with gifsicle, reduces file size", gr.Checkbox, section=section),
+    )
+
+    # MP4 encoder settings.
+    crf_info = Info(
+        default=23,
+        label="MP4 Quality (CRF)",
+        component=gr.Slider,
+        component_args={"minimum": 0, "maximum": 51, "step": 1},
+        section=section,
+    )
+    crf_info = crf_info.link("docs", "https://trac.ffmpeg.org/wiki/Encode/H.264#crf")
+    crf_info = crf_info.info("17 for best quality, up to 28 for smaller size")
+    register("animatediff_mp4_crf", crf_info)
+
+    preset_info = Info(
+        default="",
+        label="MP4 Encoding Preset",
+        component=gr.Dropdown,
+        component_args={"choices": ["", 'veryslow', 'slower', 'slow', 'medium', 'fast', 'faster', 'veryfast', 'superfast', 'ultrafast']},
+        section=section,
+    )
+    preset_info = preset_info.link("docs", "https://trac.ffmpeg.org/wiki/Encode/H.264#Preset")
+    preset_info = preset_info.info("encoding speed, use the slowest you can tolerate")
+    register("animatediff_mp4_preset", preset_info)
+
+    tune_info = Info(
+        default="",
+        label="MP4 Tune encoding for content type",
+        component=gr.Dropdown,
+        component_args={"choices": ["", "film", "animation", "grain"]},
+        section=section,
+    )
+    tune_info = tune_info.link("docs", "https://trac.ffmpeg.org/wiki/Encode/H.264#Tune")
+    tune_info = tune_info.info("optimize for specific content types")
+    register("animatediff_mp4_tune", tune_info)
+
+    # WebP settings.
+    register(
+        "animatediff_webp_quality",
+        Info(
+            80,
+            "WebP Quality (if lossless=True, increases compression and CPU usage)",
+            gr.Slider,
+            {"minimum": 1, "maximum": 100, "step": 1},
+            section=section,
+        ),
+    )
+    register(
+        "animatediff_webp_lossless",
+        Info(
+            False,
+            "Save WebP in lossless format (highest quality, largest file size)",
+            gr.Checkbox,
+            section=section,
+        ),
+    )
+
+    register(
+        "animatediff_save_to_custom",
+        Info(
+            False,
+            "Save frames to stable-diffusion-webui/outputs/{ txt|img }2img-images/AnimateDiff/{gif filename}/{date} "
+            "instead of stable-diffusion-webui/outputs/{ txt|img }2img-images/{date}/.",
+            gr.Checkbox,
+            section=section,
+        ),
+    )
+    register(
+        "animatediff_xformers",
+        Info(
+            "Optimize attention layers with xformers",
+            "When you have --xformers in your command line args, you want AnimateDiff to ",
+            gr.Radio,
+            {"choices": ["Optimize attention layers with xformers",
+                         "Optimize attention layers with sdp (torch >= 2.0.0 required)",
+                         "Do not optimize attention layers"]},
+            section=section,
+        ),
+    )
+    register(
+        "animatediff_disable_lcm",
+        Info(False, "Disable LCM", gr.Checkbox, section=section),
+    )
+
+    # S3 settings: one checkbox followed by plain textboxes.
+    register(
+        "animatediff_s3_enable",
+        Info(
+            False,
+            "Enable to Store file in object storage that supports the s3 protocol",
+            gr.Checkbox,
+            section=s3_selection,
+        ),
+    )
+    s3_textboxes = (
+        ("animatediff_s3_host", "S3 protocol host"),
+        ("animatediff_s3_port", "S3 protocol port"),
+        ("animatediff_s3_access_key", "S3 protocol access_key"),
+        ("animatediff_s3_secret_key", "S3 protocol secret_key"),
+        ("animatediff_s3_storge_bucket", "Bucket for file storage"),
+    )
+    for option_key, option_label in s3_textboxes:
+        register(option_key, Info(None, option_label, gr.Textbox, section=s3_selection))
+
+# Runs once at import time; presumably adds AnimateDiff axes to the XYZ plot script — confirm in animatediff_xyz.
+patch_xyz()
+
+# Hook this extension into the WebUI callback registry.
+script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_after_component(AnimateDiffUiGroup.on_after_component)
+script_callbacks.on_before_ui(AnimateDiffUiGroup.on_before_ui)
+script_callbacks.on_infotext_pasted(infotext_pasted)
diff --git a/sd-webui-animatediff/scripts/animatediff_cn.py b/sd-webui-animatediff/scripts/animatediff_cn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d372d6f1437ba5308ffa6cb8f6dbdd5c36427120
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_cn.py
@@ -0,0 +1,642 @@
+from pathlib import Path
+from types import MethodType
+from typing import Optional
+
+import os
+import shutil
+import cv2
+import numpy as np
+import torch
+from tqdm import tqdm
+from PIL import Image, ImageFilter, ImageOps
+from modules import processing, shared, masking, images, devices
+from modules.paths import data_path
+from modules.processing import (StableDiffusionProcessing,
+ StableDiffusionProcessingImg2Img,
+ StableDiffusionProcessingTxt2Img)
+
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_ui import AnimateDiffProcess
+from scripts.animatediff_prompt import AnimateDiffPromptSchedule
+from scripts.animatediff_infotext import update_infotext
+from scripts.animatediff_i2ibatch import animatediff_i2ibatch
+
+
+class AnimateDiffControl:
+ original_processing_process_images_hijack = None
+ original_controlnet_main_entry = None
+ original_postprocess_batch = None
+
+ def __init__(self, p: StableDiffusionProcessing, prompt_scheduler: AnimateDiffPromptSchedule):
+ try:
+ from scripts.external_code import find_cn_script
+ self.cn_script = find_cn_script(p.scripts)
+ except:
+ self.cn_script = None
+ self.prompt_scheduler = prompt_scheduler
+
+
+ def hack_batchhijack(self, params: AnimateDiffProcess):
+ cn_script = self.cn_script
+ prompt_scheduler = self.prompt_scheduler
+
+ def get_input_frames():
+ if params.video_source is not None and params.video_source != '':
+ cap = cv2.VideoCapture(params.video_source)
+ frame_count = 0
+ tmp_frame_dir = Path(f'{data_path}/tmp/animatediff-frames/')
+ tmp_frame_dir.mkdir(parents=True, exist_ok=True)
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if not ret:
+ break
+ cv2.imwrite(f"{tmp_frame_dir}/{frame_count}.png", frame)
+ frame_count += 1
+ cap.release()
+ return str(tmp_frame_dir)
+ elif params.video_path is not None and params.video_path != '':
+ return params.video_path
+ return ''
+
+ from scripts.batch_hijack import BatchHijack, instance
+ def hacked_processing_process_images_hijack(self, p: StableDiffusionProcessing, *args, **kwargs):
+ from scripts import external_code
+ from scripts.batch_hijack import InputMode
+
+ units = external_code.get_all_units_in_processing(p)
+ units = [unit for unit in units if getattr(unit, 'enabled', False)]
+
+ if len(units) > 0:
+ global_input_frames = get_input_frames()
+ for idx, unit in enumerate(units):
+ # i2i-batch mode
+ if getattr(p, '_animatediff_i2i_batch', None) and not unit.image:
+ unit.input_mode = InputMode.BATCH
+ # if no input given for this unit, use global input
+ if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
+ if not unit.batch_images:
+ assert global_input_frames, 'No input images found for ControlNet module'
+ unit.batch_images = global_input_frames
+ elif not unit.image:
+ try:
+ cn_script.choose_input_image(p, unit, idx)
+ except:
+ assert global_input_frames != '', 'No input images found for ControlNet module'
+ unit.batch_images = global_input_frames
+ unit.input_mode = InputMode.BATCH
+
+ if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
+ if 'inpaint' in unit.module:
+ images = shared.listfiles(f'{unit.batch_images}/image')
+ masks = shared.listfiles(f'{unit.batch_images}/mask')
+ assert len(images) == len(masks), 'Inpainting image mask count mismatch'
+ unit.batch_images = [{'image': images[i], 'mask': masks[i]} for i in range(len(images))]
+ else:
+ unit.batch_images = shared.listfiles(unit.batch_images)
+
+ unit_batch_list = [len(unit.batch_images) for unit in units
+ if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH]
+ if getattr(p, '_animatediff_i2i_batch', None):
+ unit_batch_list.append(len(p.init_images))
+
+ if len(unit_batch_list) > 0:
+ video_length = min(unit_batch_list)
+ # ensure that params.video_length <= video_length and params.batch_size <= video_length
+ if params.video_length > video_length:
+ params.video_length = video_length
+ if params.batch_size > video_length:
+ params.batch_size = video_length
+ if params.video_default:
+ params.video_length = video_length
+ p.batch_size = video_length
+ for unit in units:
+ if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
+ unit.batch_images = unit.batch_images[:params.video_length]
+
+ animatediff_i2ibatch.cap_init_image(p, params)
+ prompt_scheduler.parse_prompt(p, params)
+ update_infotext(p, params)
+ return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
+
+ if AnimateDiffControl.original_processing_process_images_hijack is not None:
+ logger.info('BatchHijack already hacked.')
+ return
+
+ AnimateDiffControl.original_processing_process_images_hijack = BatchHijack.processing_process_images_hijack
+ BatchHijack.processing_process_images_hijack = hacked_processing_process_images_hijack
+ processing.process_images_inner = instance.processing_process_images_hijack
+
+
+ def restore_batchhijack(self):
+ if AnimateDiffControl.original_processing_process_images_hijack is not None:
+ from scripts.batch_hijack import BatchHijack, instance
+ BatchHijack.processing_process_images_hijack = AnimateDiffControl.original_processing_process_images_hijack
+ AnimateDiffControl.original_processing_process_images_hijack = None
+ processing.process_images_inner = instance.processing_process_images_hijack
+
+
+ def hack_cn(self):
+ cn_script = self.cn_script
+
+
+ def hacked_main_entry(self, p: StableDiffusionProcessing):
+ from scripts import external_code, global_state, hook
+ from scripts.controlnet_lora import bind_control_lora
+ from scripts.adapter import Adapter, Adapter_light, StyleAdapter
+ from scripts.batch_hijack import InputMode
+ from scripts.controlnet_lllite import PlugableControlLLLite, clear_all_lllite
+ from scripts.controlmodel_ipadapter import (PlugableIPAdapter,
+ clear_all_ip_adapter)
+ from scripts.hook import ControlModelType, ControlParams, UnetHook
+ from scripts.logging import logger
+ from scripts.processor import model_free_preprocessors
+
+ # TODO: i2i-batch mode, what should I change?
+ def image_has_mask(input_image: np.ndarray) -> bool:
+ return (
+ input_image.ndim == 3 and
+ input_image.shape[2] == 4 and
+ np.max(input_image[:, :, 3]) > 127
+ )
+
+
+ def prepare_mask(
+ mask: Image.Image, p: processing.StableDiffusionProcessing
+ ) -> Image.Image:
+ mask = mask.convert("L")
+ if getattr(p, "inpainting_mask_invert", False):
+ mask = ImageOps.invert(mask)
+
+ if hasattr(p, 'mask_blur_x'):
+ if getattr(p, "mask_blur_x", 0) > 0:
+ np_mask = np.array(mask)
+ kernel_size = 2 * int(2.5 * p.mask_blur_x + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), p.mask_blur_x)
+ mask = Image.fromarray(np_mask)
+ if getattr(p, "mask_blur_y", 0) > 0:
+ np_mask = np.array(mask)
+ kernel_size = 2 * int(2.5 * p.mask_blur_y + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), p.mask_blur_y)
+ mask = Image.fromarray(np_mask)
+ else:
+ if getattr(p, "mask_blur", 0) > 0:
+ mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))
+
+ return mask
+
+
+ def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]:
+ try:
+ tmp_seed = int(p.all_seeds[0] if p.seed == -1 else max(int(p.seed), 0))
+ tmp_subseed = int(p.all_seeds[0] if p.subseed == -1 else max(int(p.subseed), 0))
+ seed = (tmp_seed + tmp_subseed) & 0xFFFFFFFF
+ np.random.seed(seed)
+ return seed
+ except Exception as e:
+ logger.warning(e)
+ logger.warning('Warning: Failed to use consistent random seed.')
+ return None
+
+ sd_ldm = p.sd_model
+ unet = sd_ldm.model.diffusion_model
+ self.noise_modifier = None
+
+ setattr(p, 'controlnet_control_loras', [])
+
+ if self.latest_network is not None:
+ # always restore (~0.05s)
+ self.latest_network.restore()
+
+ # always clear (~0.05s)
+ clear_all_lllite()
+ clear_all_ip_adapter()
+
+ self.enabled_units = cn_script.get_enabled_units(p)
+
+ if len(self.enabled_units) == 0:
+ self.latest_network = None
+ return
+
+ detected_maps = []
+ forward_params = []
+ post_processors = []
+
+ # cache stuff
+ if self.latest_model_hash != p.sd_model.sd_model_hash:
+ cn_script.clear_control_model_cache()
+
+ for idx, unit in enumerate(self.enabled_units):
+ unit.module = global_state.get_module_basename(unit.module)
+
+ # unload unused preproc
+ module_list = [unit.module for unit in self.enabled_units]
+ for key in self.unloadable:
+ if key not in module_list:
+ self.unloadable.get(key, lambda:None)()
+
+ self.latest_model_hash = p.sd_model.sd_model_hash
+ for idx, unit in enumerate(self.enabled_units):
+ cn_script.bound_check_params(unit)
+
+ resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
+ control_mode = external_code.control_mode_from_value(unit.control_mode)
+
+ if unit.module in model_free_preprocessors:
+ model_net = None
+ else:
+ control_model = cn_script.load_control_model(p, unet, unit.model)
+ model_net = control_model.model if isinstance(control_model, tuple) else control_model
+ model_net.reset()
+ if model_net is not None and getattr(devices, "fp8", False) and not isinstance(model_net, PlugableIPAdapter):
+ for _module in model_net.modules():
+ if isinstance(_module, (torch.nn.Conv2d, torch.nn.Linear)):
+ _module.to(torch.float8_e4m3fn)
+
+ if getattr(model_net, 'is_control_lora', False):
+ control_lora = model_net.control_model
+ bind_control_lora(unet, control_lora)
+ p.controlnet_control_loras.append(control_lora)
+
+ if getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH:
+ input_images = []
+ for img in unit.batch_images:
+ unit.image = img
+ input_image, _ = cn_script.choose_input_image(p, unit, idx)
+ input_images.append(input_image)
+ else:
+ input_image, image_from_a1111 = cn_script.choose_input_image(p, unit, idx)
+ input_images = [input_image]
+
+ if image_from_a1111:
+ a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
+ if a1111_i2i_resize_mode is not None:
+ resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
+
+ for idx, input_image in enumerate(input_images):
+ a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
+ if a1111_mask_image and isinstance(a1111_mask_image, list):
+ a1111_mask_image = a1111_mask_image[idx]
+ if 'inpaint' in unit.module and not image_has_mask(input_image) and a1111_mask_image is not None:
+ a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
+ if a1111_mask.ndim == 2:
+ if a1111_mask.shape[0] == input_image.shape[0]:
+ if a1111_mask.shape[1] == input_image.shape[1]:
+ input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
+ a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
+ if a1111_i2i_resize_mode is not None:
+ resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
+
+ if 'reference' not in unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) \
+ and p.inpaint_full_res and a1111_mask_image is not None:
+ logger.debug("A1111 inpaint mask START")
+ input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
+ input_image = [Image.fromarray(x) for x in input_image]
+
+ mask = prepare_mask(a1111_mask_image, p)
+
+ crop_region = masking.get_crop_region(np.array(mask), p.inpaint_full_res_padding)
+ crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)
+
+ input_image = [
+ images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
+ for i in input_image
+ ]
+
+ input_image = [x.crop(crop_region) for x in input_image]
+ input_image = [
+ images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
+ for x in input_image
+ ]
+
+ input_image = [np.asarray(x)[:, :, 0] for x in input_image]
+ input_image = np.stack(input_image, axis=2)
+ logger.debug("A1111 inpaint mask END")
+
+ # safe numpy
+ logger.debug("Safe numpy convertion START")
+ input_image = np.ascontiguousarray(input_image.copy()).copy()
+ logger.debug("Safe numpy convertion END")
+
+ input_images[idx] = input_image
+
+ if 'inpaint_only' == unit.module and issubclass(type(p), StableDiffusionProcessingImg2Img) and p.image_mask is not None:
+ logger.warning('A1111 inpaint and ControlNet inpaint duplicated. ControlNet support enabled.')
+ unit.module = 'inpaint'
+
+ logger.info(f"Loading preprocessor: {unit.module}")
+ preprocessor = self.preprocessor[unit.module]
+
+ high_res_fix = isinstance(p, StableDiffusionProcessingTxt2Img) and getattr(p, 'enable_hr', False)
+
+ h = (p.height // 8) * 8
+ w = (p.width // 8) * 8
+
+ if high_res_fix:
+ if p.hr_resize_x == 0 and p.hr_resize_y == 0:
+ hr_y = int(p.height * p.hr_scale)
+ hr_x = int(p.width * p.hr_scale)
+ else:
+ hr_y, hr_x = p.hr_resize_y, p.hr_resize_x
+ hr_y = (hr_y // 8) * 8
+ hr_x = (hr_x // 8) * 8
+ else:
+ hr_y = h
+ hr_x = w
+
+ if unit.module == 'inpaint_only+lama' and resize_mode == external_code.ResizeMode.OUTER_FIT:
+ # inpaint_only+lama is special and required outpaint fix
+ for idx, input_image in enumerate(input_images):
+ _, input_image = cn_script.detectmap_proc(input_image, unit.module, resize_mode, hr_y, hr_x)
+ input_images[idx] = input_image
+
+ control_model_type = ControlModelType.ControlNet
+ global_average_pooling = False
+
+ if 'reference' in unit.module:
+ control_model_type = ControlModelType.AttentionInjection
+ elif 'revision' in unit.module:
+ control_model_type = ControlModelType.ReVision
+ elif hasattr(model_net, 'control_model') and (isinstance(model_net.control_model, Adapter) or isinstance(model_net.control_model, Adapter_light)):
+ control_model_type = ControlModelType.T2I_Adapter
+ elif hasattr(model_net, 'control_model') and isinstance(model_net.control_model, StyleAdapter):
+ control_model_type = ControlModelType.T2I_StyleAdapter
+ elif isinstance(model_net, PlugableIPAdapter):
+ control_model_type = ControlModelType.IPAdapter
+ elif isinstance(model_net, PlugableControlLLLite):
+ control_model_type = ControlModelType.Controlllite
+
+ if control_model_type is ControlModelType.ControlNet:
+ global_average_pooling = model_net.control_model.global_average_pooling
+
+ preprocessor_resolution = unit.processor_res
+ if unit.pixel_perfect:
+ preprocessor_resolution = external_code.pixel_perfect_resolution(
+ input_images[0],
+ target_H=h,
+ target_W=w,
+ resize_mode=resize_mode
+ )
+
+ logger.info(f'preprocessor resolution = {preprocessor_resolution}')
+ # Preprocessor result may depend on numpy random operations, use the
+ # random seed in `StableDiffusionProcessing` to make the
+ # preprocessor result reproducable.
+ # Currently following preprocessors use numpy random:
+ # - shuffle
+ seed = set_numpy_seed(p)
+ logger.debug(f"Use numpy seed {seed}.")
+
+ controls = []
+ hr_controls = []
+ controls_ipadapter = {'hidden_states': [], 'image_embeds': []}
+ hr_controls_ipadapter = {'hidden_states': [], 'image_embeds': []}
+ for idx, input_image in tqdm(enumerate(input_images), total=len(input_images)):
+ detected_map, is_image = preprocessor(
+ input_image,
+ res=preprocessor_resolution,
+ thr_a=unit.threshold_a,
+ thr_b=unit.threshold_b,
+ )
+
+ if high_res_fix:
+ if is_image:
+ hr_control, hr_detected_map = cn_script.detectmap_proc(detected_map, unit.module, resize_mode, hr_y, hr_x)
+ detected_maps.append((hr_detected_map, unit.module))
+ else:
+ hr_control = detected_map
+ else:
+ hr_control = None
+
+ if is_image:
+ control, detected_map = cn_script.detectmap_proc(detected_map, unit.module, resize_mode, h, w)
+ detected_maps.append((detected_map, unit.module))
+ else:
+ control = detected_map
+ detected_maps.append((input_image, unit.module))
+
+ if control_model_type == ControlModelType.T2I_StyleAdapter:
+ control = control['last_hidden_state']
+
+ if control_model_type == ControlModelType.ReVision:
+ control = control['image_embeds']
+
+ if control_model_type == ControlModelType.IPAdapter:
+ if model_net.is_plus:
+ controls_ipadapter['hidden_states'].append(control['hidden_states'][-2].cpu())
+ else:
+ controls_ipadapter['image_embeds'].append(control['image_embeds'].cpu())
+ if hr_control is not None:
+ if model_net.is_plus:
+ hr_controls_ipadapter['hidden_states'].append(hr_control['hidden_states'][-2].cpu())
+ else:
+ hr_controls_ipadapter['image_embeds'].append(hr_control['image_embeds'].cpu())
+ else:
+ hr_controls_ipadapter = None
+ hr_controls = None
+ else:
+ controls.append(control.cpu())
+ if hr_control is not None:
+ hr_controls.append(hr_control.cpu())
+ else:
+ hr_controls = None
+
+ if control_model_type == ControlModelType.IPAdapter:
+ ipadapter_key = 'hidden_states' if model_net.is_plus else 'image_embeds'
+ controls = {ipadapter_key: torch.cat(controls_ipadapter[ipadapter_key], dim=0)}
+ if controls[ipadapter_key].shape[0] > 1:
+ controls[ipadapter_key] = torch.cat([controls[ipadapter_key], controls[ipadapter_key]], dim=0)
+ if model_net.is_plus:
+ controls[ipadapter_key] = [controls[ipadapter_key], None]
+ if hr_controls_ipadapter is not None:
+ hr_controls = {ipadapter_key: torch.cat(hr_controls_ipadapter[ipadapter_key], dim=0)}
+ if hr_controls[ipadapter_key].shape[0] > 1:
+ hr_controls[ipadapter_key] = torch.cat([hr_controls[ipadapter_key], hr_controls[ipadapter_key]], dim=0)
+ if model_net.is_plus:
+ hr_controls[ipadapter_key] = [hr_controls[ipadapter_key], None]
+ else:
+ controls = torch.cat(controls, dim=0)
+ if controls.shape[0] > 1:
+ controls = torch.cat([controls, controls], dim=0)
+ if hr_controls is not None:
+ hr_controls = torch.cat(hr_controls, dim=0)
+ if hr_controls.shape[0] > 1:
+ hr_controls = torch.cat([hr_controls, hr_controls], dim=0)
+
+ preprocessor_dict = dict(
+ name=unit.module,
+ preprocessor_resolution=preprocessor_resolution,
+ threshold_a=unit.threshold_a,
+ threshold_b=unit.threshold_b
+ )
+
+ forward_param = ControlParams(
+ control_model=model_net,
+ preprocessor=preprocessor_dict,
+ hint_cond=controls,
+ weight=unit.weight,
+ guidance_stopped=False,
+ start_guidance_percent=unit.guidance_start,
+ stop_guidance_percent=unit.guidance_end,
+ advanced_weighting=None,
+ control_model_type=control_model_type,
+ global_average_pooling=global_average_pooling,
+ hr_hint_cond=hr_controls,
+ soft_injection=control_mode != external_code.ControlMode.BALANCED,
+ cfg_injection=control_mode == external_code.ControlMode.CONTROL,
+ )
+ forward_params.append(forward_param)
+
+ unit_is_batch = getattr(unit, 'input_mode', InputMode.SIMPLE) == InputMode.BATCH
+ if 'inpaint_only' in unit.module:
+ final_inpaint_raws = []
+ final_inpaint_masks = []
+ for i in range(len(controls)):
+ final_inpaint_feed = hr_controls[i] if hr_controls is not None else controls[i]
+ final_inpaint_feed = final_inpaint_feed.detach().cpu().numpy()
+ final_inpaint_feed = np.ascontiguousarray(final_inpaint_feed).copy()
+ final_inpaint_mask = final_inpaint_feed[0, 3, :, :].astype(np.float32)
+ final_inpaint_raw = final_inpaint_feed[0, :3].astype(np.float32)
+ sigma = shared.opts.data.get("control_net_inpaint_blur_sigma", 7)
+ final_inpaint_mask = cv2.dilate(final_inpaint_mask, np.ones((sigma, sigma), dtype=np.uint8))
+ final_inpaint_mask = cv2.blur(final_inpaint_mask, (sigma, sigma))[None]
+ _, Hmask, Wmask = final_inpaint_mask.shape
+ final_inpaint_raw = torch.from_numpy(np.ascontiguousarray(final_inpaint_raw).copy())
+ final_inpaint_mask = torch.from_numpy(np.ascontiguousarray(final_inpaint_mask).copy())
+ final_inpaint_raws.append(final_inpaint_raw)
+ final_inpaint_masks.append(final_inpaint_mask)
+
+ def inpaint_only_post_processing(x, i):
+ _, H, W = x.shape
+ if Hmask != H or Wmask != W:
+ logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
+ return x
+ idx = i if unit_is_batch else 0
+ r = final_inpaint_raw[idx].to(x.dtype).to(x.device)
+ m = final_inpaint_mask[idx].to(x.dtype).to(x.device)
+ y = m * x.clip(0, 1) + (1 - m) * r
+ y = y.clip(0, 1)
+ return y
+
+ post_processors.append(inpaint_only_post_processing)
+
+ if 'recolor' in unit.module:
+ final_feeds = []
+ for i in range(len(controls)):
+ final_feed = hr_control if hr_control is not None else control
+ final_feed = final_feed.detach().cpu().numpy()
+ final_feed = np.ascontiguousarray(final_feed).copy()
+ final_feed = final_feed[0, 0, :, :].astype(np.float32)
+ final_feed = (final_feed * 255).clip(0, 255).astype(np.uint8)
+ Hfeed, Wfeed = final_feed.shape
+ final_feeds.append(final_feed)
+
+ if 'luminance' in unit.module:
+
+ def recolor_luminance_post_processing(x, i):
+ C, H, W = x.shape
+ if Hfeed != H or Wfeed != W or C != 3:
+ logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
+ return x
+ h = x.detach().cpu().numpy().transpose((1, 2, 0))
+ h = (h * 255).clip(0, 255).astype(np.uint8)
+ h = cv2.cvtColor(h, cv2.COLOR_RGB2LAB)
+ h[:, :, 0] = final_feed[i if unit_is_batch else 0]
+ h = cv2.cvtColor(h, cv2.COLOR_LAB2RGB)
+ h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
+ y = torch.from_numpy(h).clip(0, 1).to(x)
+ return y
+
+ post_processors.append(recolor_luminance_post_processing)
+
+ if 'intensity' in unit.module:
+
+ def recolor_intensity_post_processing(x, i):
+ C, H, W = x.shape
+ if Hfeed != H or Wfeed != W or C != 3:
+ logger.error('Error: ControlNet find post-processing resolution mismatch. This could be related to other extensions hacked processing.')
+ return x
+ h = x.detach().cpu().numpy().transpose((1, 2, 0))
+ h = (h * 255).clip(0, 255).astype(np.uint8)
+ h = cv2.cvtColor(h, cv2.COLOR_RGB2HSV)
+ h[:, :, 2] = final_feed[i if unit_is_batch else 0]
+ h = cv2.cvtColor(h, cv2.COLOR_HSV2RGB)
+ h = (h.astype(np.float32) / 255.0).transpose((2, 0, 1))
+ y = torch.from_numpy(h).clip(0, 1).to(x)
+ return y
+
+ post_processors.append(recolor_intensity_post_processing)
+
+ if '+lama' in unit.module:
+ forward_param.used_hint_cond_latent = hook.UnetHook.call_vae_using_process(p, control)
+ self.noise_modifier = forward_param.used_hint_cond_latent
+
+ del model_net
+
+ is_low_vram = any(unit.low_vram for unit in self.enabled_units)
+
+ self.latest_network = UnetHook(lowvram=is_low_vram)
+ self.latest_network.hook(model=unet, sd_ldm=sd_ldm, control_params=forward_params, process=p)
+
+ for param in forward_params:
+ if param.control_model_type == ControlModelType.IPAdapter:
+ param.control_model.hook(
+ model=unet,
+ clip_vision_output=param.hint_cond,
+ weight=param.weight,
+ dtype=torch.float32,
+ start=param.start_guidance_percent,
+ end=param.stop_guidance_percent
+ )
+ if param.control_model_type == ControlModelType.Controlllite:
+ param.control_model.hook(
+ model=unet,
+ cond=param.hint_cond,
+ weight=param.weight,
+ start=param.start_guidance_percent,
+ end=param.stop_guidance_percent
+ )
+
+ self.detected_map = detected_maps
+ self.post_processors = post_processors
+
+ if os.path.exists(f'{data_path}/tmp/animatediff-frames/'):
+ shutil.rmtree(f'{data_path}/tmp/animatediff-frames/')
+
+ def hacked_postprocess_batch(self, p, *args, **kwargs):
+ images = kwargs.get('images', [])
+ for post_processor in self.post_processors:
+ for i in range(len(images)):
+ images[i] = post_processor(images[i], i)
+ return
+
+ if AnimateDiffControl.original_controlnet_main_entry is not None:
+ logger.info('ControlNet Main Entry already hacked.')
+ return
+
+ AnimateDiffControl.original_controlnet_main_entry = self.cn_script.controlnet_main_entry
+ AnimateDiffControl.original_postprocess_batch = self.cn_script.postprocess_batch
+ self.cn_script.controlnet_main_entry = MethodType(hacked_main_entry, self.cn_script)
+ self.cn_script.postprocess_batch = MethodType(hacked_postprocess_batch, self.cn_script)
+
+
+ def restore_cn(self):
+ if AnimateDiffControl.original_controlnet_main_entry is not None:
+ self.cn_script.controlnet_main_entry = AnimateDiffControl.original_controlnet_main_entry
+ AnimateDiffControl.original_controlnet_main_entry = None
+ if AnimateDiffControl.original_postprocess_batch is not None:
+ self.cn_script.postprocess_batch = AnimateDiffControl.original_postprocess_batch
+ AnimateDiffControl.original_postprocess_batch = None
+
+
+ def hack(self, params: AnimateDiffProcess):
+ if self.cn_script is not None:
+ logger.info(f"Hacking ControlNet.")
+ self.hack_batchhijack(params)
+ self.hack_cn()
+
+
+ def restore(self):
+ if self.cn_script is not None:
+ logger.info(f"Restoring ControlNet.")
+ self.restore_batchhijack()
+ self.restore_cn()
diff --git a/sd-webui-animatediff/scripts/animatediff_i2ibatch.py b/sd-webui-animatediff/scripts/animatediff_i2ibatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a308077c2930f0fca47b46db9459fac777741923
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_i2ibatch.py
@@ -0,0 +1,309 @@
+from pathlib import Path
+from types import MethodType
+
+import os
+import cv2
+import numpy as np
+import torch
+import hashlib
+from PIL import Image, ImageOps, UnidentifiedImageError
+from modules import processing, shared, scripts, img2img, devices, masking, sd_samplers, images
+from modules.processing import (StableDiffusionProcessingImg2Img,
+ process_images,
+ create_binary_mask,
+ create_random_tensors,
+ images_tensor_to_samples,
+ setup_color_correction,
+ opt_f)
+from modules.shared import opts
+from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes
+
+from scripts.animatediff_logger import logger_animatediff as logger
+
+
+class AnimateDiffI2IBatch:
+ original_img2img_process_batch = None
+
+ def hack(self):
+ # TODO: PR this hack to A1111
+ if AnimateDiffI2IBatch.original_img2img_process_batch is not None:
+ logger.info("Hacking i2i-batch is already done.")
+ return
+
+ logger.info("Hacking i2i-batch.")
+ AnimateDiffI2IBatch.original_img2img_process_batch = img2img.process_batch
+ original_img2img_process_batch = AnimateDiffI2IBatch.original_img2img_process_batch
+
+ def hacked_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask
+ self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
+
+ self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
+ crop_regions = []
+ paste_to = []
+ masks_for_overlay = []
+
+ image_masks = self.image_mask
+
+ for idx, image_mask in enumerate(image_masks):
+ # image_mask is passed in as RGBA by Gradio to support alpha masks,
+ # but we still want to support binary masks.
+ image_mask = create_binary_mask(image_mask)
+
+ if self.inpainting_mask_invert:
+ image_mask = ImageOps.invert(image_mask)
+
+ if self.mask_blur_x > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+ image_mask = Image.fromarray(np_mask)
+
+ if self.mask_blur_y > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+ image_mask = Image.fromarray(np_mask)
+
+ if self.inpaint_full_res:
+ masks_for_overlay.append(image_mask)
+ mask = image_mask.convert('L')
+ crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
+ crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
+ crop_regions.append(crop_region)
+ x1, y1, x2, y2 = crop_region
+
+ mask = mask.crop(crop_region)
+ image_mask = images.resize_image(2, mask, self.width, self.height)
+ paste_to.append((x1, y1, x2-x1, y2-y1))
+ else:
+ image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
+ np_mask = np.array(image_mask)
+ np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
+ masks_for_overlay.append(Image.fromarray(np_mask))
+
+ image_masks[idx] = image_mask
+
+ self.mask_for_overlay = masks_for_overlay[0] # only for saving purpose
+ if paste_to:
+ self.paste_to = paste_to[0]
+ self._animatediff_paste_to_full = paste_to
+
+ self.overlay_images = []
+ add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
+ if add_color_corrections:
+ self.color_corrections = []
+ imgs = []
+ for idx, img in enumerate(self.init_images):
+ latent_mask = (self.latent_mask[idx] if isinstance(self.latent_mask, list) else self.latent_mask) if self.latent_mask is not None else image_masks[idx]
+ # Save init image
+ if opts.save_init_img:
+ self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
+ images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
+
+ image = images.flatten(img, opts.img2img_background_color)
+
+ if not crop_regions and self.resize_mode != 3:
+ image = images.resize_image(self.resize_mode, image, self.width, self.height)
+
+ if image_masks:
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(masks_for_overlay[idx].convert('L')))
+
+ self.overlay_images.append(image_masked.convert('RGBA'))
+
+ # crop_region is not None if we are doing inpaint full res
+ if crop_regions:
+ image = image.crop(crop_regions[idx])
+ image = images.resize_image(2, image, self.width, self.height)
+
+ if image_masks:
+ if self.inpainting_fill != 1:
+ image = masking.fill(image, latent_mask)
+
+ if add_color_corrections:
+ self.color_corrections.append(setup_color_correction(image))
+
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+
+ imgs.append(image)
+
+ if len(imgs) == 1:
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+ if self.overlay_images is not None:
+ self.overlay_images = self.overlay_images * self.batch_size
+
+ if self.color_corrections is not None and len(self.color_corrections) == 1:
+ self.color_corrections = self.color_corrections * self.batch_size
+
+ elif len(imgs) <= self.batch_size:
+ self.batch_size = len(imgs)
+ batch_images = np.array(imgs)
+ else:
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
+
+ image = torch.from_numpy(batch_images)
+ image = image.to(shared.device, dtype=devices.dtype_vae)
+
+ if opts.sd_vae_encode_method != 'Full':
+ self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+ self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+ devices.torch_gc()
+
+ if self.resize_mode == 3:
+ self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+ if image_masks is not None:
+ def process_letmask(init_mask):
+ # init_mask = latent_mask
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
+ latmask = latmask[0]
+ latmask = np.around(latmask)
+ return np.tile(latmask[None], (4, 1, 1))
+
+ if self.latent_mask is not None and not isinstance(self.latent_mask, list):
+ latmask = process_letmask(self.latent_mask)
+ else:
+ if isinstance(self.latent_mask, list):
+ latmask = [process_letmask(x) for x in self.latent_mask]
+ else:
+ latmask = [process_letmask(x) for x in image_masks]
+ latmask = np.stack(latmask, axis=0)
+
+ self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
+ self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
+
+ # this needs to be fixed to be done in sample() using actual seeds for batches
+ if self.inpainting_fill == 2:
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
+ elif self.inpainting_fill == 3:
+ self.init_latent = self.init_latent * self.mask
+
+ self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks) # let's ignore this image_masks which is related to inpaint model with different arch
+
+ def hacked_img2img_process_batch_hijack(
+ p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str,
+ args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
+ if p.scripts:
+ for script in p.scripts.alwayson_scripts:
+ if script.title().lower() == "animatediff":
+ ad_arg = p.script_args[script.args_from]
+ ad_enabled = ad_arg.get('enable', False) if isinstance(ad_arg, dict) else getattr(ad_arg, 'enable', False)
+ if ad_enabled:
+ p._animatediff_i2i_batch = 1 # i2i-batch mode, ordinary
+
+ if not hasattr(p, '_animatediff_i2i_batch'):
+ return original_img2img_process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale, scale_by, use_png_info, png_info_props, png_info_dir)
+ output_dir = output_dir.strip()
+ processing.fix_seed(p)
+
+ images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
+
+ is_inpaint_batch = False
+ if inpaint_mask_dir:
+ inpaint_masks = shared.listfiles(inpaint_mask_dir)
+ is_inpaint_batch = bool(inpaint_masks)
+
+ if is_inpaint_batch:
+ assert len(inpaint_masks) == 1 or len(inpaint_masks) == len(images), 'The number of masks must be 1 or equal to the number of images.'
+ logger.info(f"\n[i2i batch] Inpaint batch is enabled. {len(inpaint_masks)} masks found.")
+ if len(inpaint_masks) > 1: # batch mask
+ p.init = MethodType(hacked_i2i_init, p)
+
+ logger.info(f"[i2i batch] Will process {len(images)} images, creating {p.n_iter} new videos.")
+
+ # extract "default" params to use in case getting png info fails
+ prompt = p.prompt
+ negative_prompt = p.negative_prompt
+ seed = p.seed
+ cfg_scale = p.cfg_scale
+ sampler_name = p.sampler_name
+ steps = p.steps
+ frame_images = []
+ frame_masks = []
+
+ for i, image in enumerate(images):
+
+ try:
+ img = Image.open(image)
+ except UnidentifiedImageError as e:
+ print(e)
+ continue
+ # Use the EXIF orientation of photos taken by smartphones.
+ img = ImageOps.exif_transpose(img)
+
+ if to_scale:
+ p.width = int(img.width * scale_by)
+ p.height = int(img.height * scale_by)
+
+ frame_images.append(img)
+
+ image_path = Path(image)
+ if is_inpaint_batch:
+ if len(inpaint_masks) == 1:
+ mask_image_path = inpaint_masks[0]
+ p.image_mask = Image.open(mask_image_path)
+ else:
+ # try to find corresponding mask for an image using index matching
+ mask_image_path = inpaint_masks[i]
+ frame_masks.append(Image.open(mask_image_path))
+
+ mask_image = Image.open(mask_image_path)
+ p.image_mask = mask_image
+
+ if use_png_info:
+ try:
+ info_img = frame_images[0]
+ if png_info_dir:
+ info_img_path = os.path.join(png_info_dir, os.path.basename(image))
+ info_img = Image.open(info_img_path)
+ from modules import images as imgutil
+ from modules.generation_parameters_copypaste import parse_generation_parameters
+ geninfo, _ = imgutil.read_info_from_image(info_img)
+ parsed_parameters = parse_generation_parameters(geninfo)
+ parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
+ except Exception:
+ parsed_parameters = {}
+
+ p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
+ p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
+ p.seed = int(parsed_parameters.get("Seed", seed))
+ p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
+ p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
+ p.steps = int(parsed_parameters.get("Steps", steps))
+
+ p.init_images = frame_images
+ if len(frame_masks) > 0:
+ p.image_mask = frame_masks
+
+ proc = scripts.scripts_img2img.run(p, *args) # we should not support this, but just leave it here
+ if proc is None:
+ if output_dir:
+ p.outpath_samples = output_dir
+ p.override_settings['save_to_dirs'] = False
+ if p.n_iter > 1 or p.batch_size > 1:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+ else:
+ p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+ return process_images(p)
+ else:
+ logger.warn("Warning: you are using an unsupported external script. AnimateDiff may not work properly.")
+
+ img2img.process_batch = hacked_img2img_process_batch_hijack
+
+
+ # Make the img2img batch inputs and the AnimateDiff params agree on the frame count.
+ # Only acts when params.enable is truthy, `p` is an img2img job, and `p` carries the
+ # `_animatediff_i2i_batch` attribute. Mutates `p` and `params` in place; returns nothing.
+ def cap_init_image(self, p: StableDiffusionProcessingImg2Img, params):
+ if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and hasattr(p, '_animatediff_i2i_batch'):
+ # More input images than requested frames: keep only the first `video_length`.
+ if len(p.init_images) > params.video_length:
+ p.init_images = p.init_images[:params.video_length]
+ # A list-valued mask is truncated the same way; a non-list mask is left alone.
+ if p.image_mask and isinstance(p.image_mask, list) and len(p.image_mask) > params.video_length:
+ p.image_mask = p.image_mask[:params.video_length]
+ # Fewer input images than requested: shrink video_length / batch_size to what exists.
+ if len(p.init_images) < params.video_length:
+ params.video_length = len(p.init_images)
+ if len(p.init_images) < params.batch_size:
+ params.batch_size = len(p.init_images)
+
+
+animatediff_i2ibatch = AnimateDiffI2IBatch()
diff --git a/sd-webui-animatediff/scripts/animatediff_infotext.py b/sd-webui-animatediff/scripts/animatediff_infotext.py
new file mode 100644
index 0000000000000000000000000000000000000000..e18a2a04d4416245a4c1d404d34a68a130907e04
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_infotext.py
@@ -0,0 +1,35 @@
+import os
+
+from modules.paths import data_path
+from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img
+
+from scripts.animatediff_ui import AnimateDiffProcess
+from scripts.animatediff_logger import logger_animatediff as logger
+
+
+def update_infotext(p: StableDiffusionProcessing, params: AnimateDiffProcess):
+    """Record the AnimateDiff settings under the "AnimateDiff" infotext key.
+
+    Does nothing when ``p.extra_generation_params`` is ``None``. The value stored is
+    ``params.get_dict(...)``, called with a flag telling whether ``p`` is an img2img job.
+    """
+    extra_params = p.extra_generation_params
+    if extra_params is None:
+        return
+    is_img2img = isinstance(p, StableDiffusionProcessingImg2Img)
+    extra_params["AnimateDiff"] = params.get_dict(is_img2img)
+
+
+def write_params_txt(info: str):
+    """Overwrite ``params.txt`` inside ``data_path`` with ``info`` (UTF-8)."""
+    params_file = os.path.join(data_path, "params.txt")
+    with open(params_file, mode="w", encoding="utf8") as out:
+        out.write(info)
+
+
+
+def infotext_pasted(infotext, results):
+    """Expand the packed AnimateDiff infotext entry into one key per field.
+
+    The first key in ``results`` that starts with ``"AnimateDiff"`` is expected to hold a
+    string of the form ``"field1: value1, field2: value2"``. Every pair is written back as
+    ``results["AnimateDiff <field>"] = <value>`` and the packed entry is removed. Only the
+    first matching key is processed.
+
+    If the value cannot be parsed, a warning is logged and ``results`` is left untouched
+    (no half-expanded state).
+
+    Args:
+        infotext: Raw infotext string; not used by this function.
+        results: Dict of parsed infotext entries; modified in place.
+
+    Raises:
+        AssertionError: If the matched value is not a string.
+    """
+    # Iterate over a snapshot because ``results`` is modified below.
+    for key, packed in list(results.items()):
+        if not key.startswith("AnimateDiff"):
+            continue
+
+        assert isinstance(packed, str), f"Expected string but got {packed}."
+        try:
+            expanded = {}
+            for item in packed.split(', '):
+                # Split on the first separator only, so values containing ": " survive.
+                field, value = item.split(': ', 1)
+                expanded[f"AnimateDiff {field}"] = value
+        except Exception as e:
+            # Logger.warn is deprecated (removed in Python 3.13); use warning().
+            logger.warning(f"Failed to parse infotext value:\n{packed}")
+            logger.warning(f"Exception: {e}")
+        else:
+            # Remove the entry that was actually matched, then publish the parsed fields.
+            results.pop(key, None)
+            results.update(expanded)
+        break
diff --git a/sd-webui-animatediff/scripts/animatediff_infv2v.py b/sd-webui-animatediff/scripts/animatediff_infv2v.py
new file mode 100644
index 0000000000000000000000000000000000000000..b057026798eb971a3a2f2b9400550dae8a6ddd8e
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_infv2v.py
@@ -0,0 +1,322 @@
+from typing import List
+
+import numpy as np
+import torch
+
+from modules import prompt_parser, devices, sd_samplers_common, shared
+from modules.shared import opts, state
+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+from modules.sd_samplers_cfg_denoiser import CFGDenoiser, catenate_conds, subscript_cond, pad_cond
+
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_ui import AnimateDiffProcess
+from scripts.animatediff_prompt import AnimateDiffPromptSchedule
+
+
+class AnimateDiffInfV2V:
+ cfg_original_forward = None
+
+ def __init__(self, p, prompt_scheduler: AnimateDiffPromptSchedule):
+     """Remember the prompt scheduler and, when available, the script found by find_cn_script.
+
+     Args:
+         p: Processing object; only ``p.scripts`` is read here.
+         prompt_scheduler: Scheduler later used to build per-frame conditioning.
+     """
+     try:
+         # Optional dependency (presumably the ControlNet extension); if the import or
+         # the lookup fails, the control-image handling is simply disabled.
+         from scripts.external_code import find_cn_script
+         self.cn_script = find_cn_script(p.scripts)
+     except Exception:
+         # Not a bare ``except``: KeyboardInterrupt / SystemExit must still propagate.
+         self.cn_script = None
+     self.prompt_scheduler = prompt_scheduler
+
+
+ # Returns a fraction whose denominator is a power of 2: the bits of `val`, written as a
+ # zero-padded 64-digit binary string, are reversed and the result is divided by 2**64.
+ @staticmethod
+ def ordered_halving(val):
+     mirrored_bits = format(val, "064b")[::-1]
+     return int(mirrored_bits, 2) / (1 << 64)
+
+
+ # Generator that yields lists of latent indices ("contexts") to diffuse on.
+ # Each yielded list holds frame indices taken modulo video_length; the caller runs the
+ # model once per yielded context.
+ @staticmethod
+ def uniform(
+ # sampling step counter; fed to ordered_halving() to shift the window start
+ step: int = ...,
+ video_length: int = 0,
+ # number of indices in each yielded context
+ batch_size: int = 16,
+ # upper bound on how many power-of-two index spacings (1, 2, 4, ...) are used
+ stride: int = 1,
+ # number of indices shared by consecutive windows
+ overlap: int = 4,
+ # 'A' extends the windows over the wrap-around; 'N' avoids yielding wrapped contexts at spacing 1
+ loop_setting: str = 'R-P',
+ ):
+ # Everything fits in one batch: a single context of batch_size indices, then stop.
+ if video_length <= batch_size:
+ yield list(range(batch_size))
+ return
+
+ closed_loop = (loop_setting == 'A')
+ # Never use more spacings than ceil(log2(video_length / batch_size)) + 1.
+ stride = min(stride, int(np.ceil(np.log2(video_length / batch_size))) + 1)
+
+ # context_step takes the values 1, 2, 4, ... (one per stride level).
+ for context_step in 1 << np.arange(stride):
+ # Step-dependent offset of the first window.
+ pad = int(round(video_length * AnimateDiffInfV2V.ordered_halving(step)))
+ both_close_loop = False
+ for j in range(
+ int(AnimateDiffInfV2V.ordered_halving(step) * context_step) + pad,
+ video_length + pad + (0 if closed_loop else -overlap),
+ (batch_size * context_step - overlap),
+ ):
+ if loop_setting == 'N' and context_step == 1:
+ current_context = [e % video_length for e in range(j, j + batch_size * context_step, context_step)]
+ first_context = [e % video_length for e in range(0, batch_size * context_step, context_step)]
+ last_context = [e % video_length for e in range(video_length - batch_size * context_step, video_length, context_step)]
+ # Position of the first index that is smaller than its predecessor, i.e. where
+ # the window wrapped past the end of the video; None if it did not wrap.
+ def get_unsorted_index(lst):
+ for i in range(1, len(lst)):
+ if lst[i] < lst[i-1]:
+ return i
+ return None
+ unsorted_index = get_unsorted_index(current_context)
+ # A wrapped window is replaced by the last and/or first full window of the video.
+ if unsorted_index is None:
+ yield current_context
+ elif both_close_loop: # last and this context are close loop
+ both_close_loop = False
+ yield first_context
+ elif unsorted_index < batch_size - overlap: # only this context is close loop
+ yield last_context
+ yield first_context
+ else: # this and next context are close loop
+ both_close_loop = True
+ yield last_context
+ else:
+ yield [e % video_length for e in range(j, j + batch_size * context_step, context_step)]
+
+
+ def hack(self, params: AnimateDiffProcess):
+ if AnimateDiffInfV2V.cfg_original_forward is not None:
+ logger.info("CFGDenoiser already hacked")
+ return
+
+ logger.info(f"Hacking CFGDenoiser forward function.")
+ AnimateDiffInfV2V.cfg_original_forward = CFGDenoiser.forward
+ cn_script = self.cn_script
+ prompt_scheduler = self.prompt_scheduler
+
+ # Narrow every control input down to the frames listed in `context`.
+ # For each tensor that has more rows along dim 0 than the context, the full tensor is
+ # kept in a `*_backup` attribute and replaced by its `context` slice.
+ # No-op when no cn_script was found or it has no latest_network.
+ def mm_cn_select(context: List[int]):
+ # take control images for current context.
+ if cn_script and cn_script.latest_network:
+ from scripts.hook import ControlModelType
+ for control in cn_script.latest_network.control_params:
+ if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]:
+ if control.hint_cond.shape[0] > len(context):
+ control.hint_cond_backup = control.hint_cond
+ control.hint_cond = control.hint_cond[context]
+ # The hint is moved to the "controlnet" device.
+ control.hint_cond = control.hint_cond.to(device=devices.get_device_for("controlnet"))
+ # Same treatment for the hr hint, when there is one.
+ if control.hr_hint_cond is not None:
+ if control.hr_hint_cond.shape[0] > len(context):
+ control.hr_hint_cond_backup = control.hr_hint_cond
+ control.hr_hint_cond = control.hr_hint_cond[context]
+ control.hr_hint_cond = control.hr_hint_cond.to(device=devices.get_device_for("controlnet"))
+ # IPAdapter and Controlllite are always on CPU.
+ elif control.control_model_type == ControlModelType.IPAdapter and control.control_model.image_emb.shape[0] > len(context):
+ control.control_model.image_emb_backup = control.control_model.image_emb
+ control.control_model.image_emb = control.control_model.image_emb[context]
+ control.control_model.uncond_image_emb_backup = control.control_model.uncond_image_emb
+ control.control_model.uncond_image_emb = control.control_model.uncond_image_emb[context]
+ elif control.control_model_type == ControlModelType.Controlllite:
+ for module in control.control_model.modules.values():
+ if module.cond_image.shape[0] > len(context):
+ module.cond_image_backup = module.cond_image
+ module.set_cond_image(module.cond_image[context])
+
+ # Counterpart of mm_cn_select: put the full-length control inputs back.
+ # NOTE(review): the `*_backup` attributes are never cleared here, so a backup left from an
+ # earlier call is reused whenever it exists - confirm that this is intended.
+ def mm_cn_restore(context: List[int]):
+ # restore control images for next context
+ if cn_script and cn_script.latest_network:
+ from scripts.hook import ControlModelType
+ for control in cn_script.latest_network.control_params:
+ if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]:
+ if getattr(control, "hint_cond_backup", None) is not None:
+ # Write the current slice back into the backup (on CPU), then restore the backup.
+ control.hint_cond_backup[context] = control.hint_cond.to(device="cpu")
+ control.hint_cond = control.hint_cond_backup
+ if control.hr_hint_cond is not None and getattr(control, "hr_hint_cond_backup", None) is not None:
+ control.hr_hint_cond_backup[context] = control.hr_hint_cond.to(device="cpu")
+ control.hr_hint_cond = control.hr_hint_cond_backup
+ # IPAdapter embeddings are restored from the backup without writing the slice back.
+ elif control.control_model_type == ControlModelType.IPAdapter and getattr(control.control_model, "image_emb_backup", None) is not None:
+ control.control_model.image_emb = control.control_model.image_emb_backup
+ control.control_model.uncond_image_emb = control.control_model.uncond_image_emb_backup
+ elif control.control_model_type == ControlModelType.Controlllite:
+ for module in control.control_model.modules.values():
+ if getattr(module, "cond_image_backup", None) is not None:
+ module.set_cond_image(module.cond_image_backup)
+
+ # Run the inner model once per context window produced by AnimateDiffInfV2V.uniform and
+ # scatter each result into x_out at the same indices. Rows that belong to several windows
+ # keep the output of the last window that wrote them; rows never selected stay zero.
+ def mm_sd_forward(self, x_in, sigma_in, cond_in, image_cond_in, make_condition_dict):
+ x_out = torch.zeros_like(x_in)
+ for context in AnimateDiffInfV2V.uniform(self.step, params.video_length, params.batch_size, params.stride, params.overlap, params.closed_loop):
+ if shared.opts.batch_cond_uncond:
+ # Also select the same frames offset by video_length
+ # (presumably the uncond half of the batch - TODO confirm against caller).
+ _context = context + [c + params.video_length for c in context]
+ else:
+ _context = context
+ mm_cn_select(_context)
+ out = self.inner_model(
+ x_in[_context], sigma_in[_context],
+ cond=make_condition_dict(
+ cond_in[_context] if not isinstance(cond_in, dict) else {k: v[_context] for k, v in cond_in.items()},
+ image_cond_in[_context]))
+ # Match the model output dtype before writing the slice.
+ x_out = x_out.to(dtype=out.dtype)
+ x_out[_context] = out
+ mm_cn_restore(_context)
+ return x_out
+
+ # Replacement for CFGDenoiser.forward (assigned to it at the end of the enclosing function).
+ # Lines tagged `# hook` are where this version calls AnimateDiff code (prompt_scheduler,
+ # mm_sd_forward) instead of calling self.inner_model directly.
+ # Returns the denoised latent after the cfg_after_cfg callback; increments self.step.
+ def mm_cfg_forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+ if state.interrupted or state.skipped:
+ raise sd_samplers_common.InterruptedException
+
+ if sd_samplers_common.apply_refiner(self):
+ cond = self.sampler.sampler_extra_args['cond']
+ uncond = self.sampler.sampler_extra_args['uncond']
+
+ # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+ # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
+ conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
+ assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
+ if self.mask_before_denoising and self.mask is not None:
+ x = self.init_latent * self.mask + self.nmask * x
+
+ batch_size = len(conds_list)
+ repeats = [len(conds_list[i]) for i in range(batch_size)]
+
+ # make_condition_dict packs text conditioning and image conditioning into the dict
+ # passed to the inner model as `cond`.
+ if shared.sd_model.model.conditioning_key == "crossattn-adm":
+ image_uncond = torch.zeros_like(image_cond) # this should not be supported.
+ make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
+ else:
+ image_uncond = image_cond
+ if isinstance(uncond, dict):
+ make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+ else:
+ make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
+
+ # Build the model input: each x[i] repeated once per sub-condition, followed by one
+ # copy of x for uncond (two copies for the edit model).
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
+
+ # Callbacks receive the inputs through denoiser_params; the (possibly replaced)
+ # fields are read back afterwards.
+ denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
+ cfg_denoiser_callback(denoiser_params)
+ x_in = denoiser_params.x
+ image_cond_in = denoiser_params.image_cond
+ sigma_in = denoiser_params.sigma
+ tensor = denoiser_params.text_cond
+ uncond = denoiser_params.text_uncond
+ skip_uncond = False
+
+ # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
+ if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
+ skip_uncond = True
+ x_in = x_in[:-batch_size]
+ sigma_in = sigma_in[:-batch_size]
+
+ # Optionally pad the shorter of cond / uncond with the empty-prompt embedding so that
+ # both have the same size along dim 1.
+ self.padded_cond_uncond = False
+ if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
+ empty = shared.sd_model.cond_stage_model_empty_prompt
+ num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
+
+ if num_repeats < 0:
+ tensor = pad_cond(tensor, -num_repeats, empty)
+ self.padded_cond_uncond = True
+ elif num_repeats > 0:
+ uncond = pad_cond(uncond, num_repeats, empty)
+ self.padded_cond_uncond = True
+
+ # Path 1: cond and uncond can be batched together (or uncond is skipped); the model is
+ # run through mm_sd_forward.
+ if tensor.shape[1] == uncond.shape[1] or skip_uncond:
+ prompt_closed_loop = (params.video_length > params.batch_size) and (params.closed_loop in ['R+P', 'A']) # hook
+ tensor = prompt_scheduler.multi_cond(tensor, prompt_closed_loop) # hook
+ if is_edit_model:
+ cond_in = catenate_conds([tensor, uncond, uncond])
+ elif skip_uncond:
+ cond_in = tensor
+ else:
+ cond_in = catenate_conds([tensor, uncond])
+
+ if shared.opts.batch_cond_uncond:
+ x_out = mm_sd_forward(self, x_in, sigma_in, cond_in, image_cond_in, make_condition_dict) # hook
+ else:
+ x_out = torch.zeros_like(x_in)
+ for batch_offset in range(0, x_out.shape[0], batch_size):
+ a = batch_offset
+ b = a + batch_size
+ x_out[a:b] = mm_sd_forward(self, x_in[a:b], sigma_in[a:b], subscript_cond(cond_in, a, b), subscript_cond(image_cond_in, a, b), make_condition_dict) # hook
+ # Path 2: cond and uncond differ in size along dim 1; this path calls self.inner_model
+ # directly and does not go through mm_sd_forward.
+ else:
+ x_out = torch.zeros_like(x_in)
+ batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size
+ for batch_offset in range(0, tensor.shape[0], batch_size):
+ a = batch_offset
+ b = min(a + batch_size, tensor.shape[0])
+
+ if not is_edit_model:
+ c_crossattn = subscript_cond(tensor, a, b)
+ else:
+ # NOTE(review): `uncond` is passed as the second positional argument of torch.cat,
+ # which is `dim`, not part of the tensor list - looks unintended; confirm.
+ c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
+
+ if not skip_uncond:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
+
+ # Index into x_out of the first sub-condition of every image.
+ denoised_image_indexes = [x[0][0] for x in conds_list]
+ if skip_uncond:
+ fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
+ x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
+
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
+ cfg_denoised_callback(denoised_params)
+
+ devices.test_for_nans(x_out, "unet")
+
+ # Combine the cond / uncond outputs; when uncond was skipped a scale of 1.0 is used.
+ if is_edit_model:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
+ elif skip_uncond:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+ else:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+
+ if not self.mask_before_denoising and self.mask is not None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+
+ self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
+
+ # Choose which latent is stored for the live preview.
+ if opts.live_preview_content == "Prompt":
+ preview = self.sampler.last_latent
+ elif opts.live_preview_content == "Negative prompt":
+ preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
+ else:
+ preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
+
+ sd_samplers_common.store_latent(preview)
+
+ after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
+ cfg_after_cfg_callback(after_cfg_callback_params)
+ denoised = after_cfg_callback_params.x
+
+ self.step += 1
+ return denoised
+
+ CFGDenoiser.forward = mm_cfg_forward
+
+
+ def restore(self):
+     """Put the saved CFGDenoiser.forward back and clear the saved reference.
+
+     Only logs a message when nothing is saved in ``cfg_original_forward``.
+     """
+     saved_forward = AnimateDiffInfV2V.cfg_original_forward
+     if saved_forward is not None:
+         logger.info(f"Restoring CFGDenoiser forward function.")
+         CFGDenoiser.forward = saved_forward
+         AnimateDiffInfV2V.cfg_original_forward = None
+     else:
+         logger.info("CFGDenoiser already restored.")
diff --git a/sd-webui-animatediff/scripts/animatediff_latent.py b/sd-webui-animatediff/scripts/animatediff_latent.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d9704c8bcdef469015855ed77b889a65ebd35b
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_latent.py
@@ -0,0 +1,84 @@
+import numpy as np
+import torch
+from modules import images, shared
+from modules.devices import device, dtype_vae, torch_gc
+from modules.processing import StableDiffusionProcessingImg2Img
+from modules.sd_samplers_common import (approximation_indexes,
+ images_tensor_to_samples)
+
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_ui import AnimateDiffProcess
+
+
+class AnimateDiffI2VLatent:
+    """Blends the img2img init latent (and an optional last frame) with fresh noise per frame."""
+
+    def randomize(
+        self, p: StableDiffusionProcessingImg2Img, params: AnimateDiffProcess
+    ):
+        """Overwrite ``p.init_latent`` with a per-frame mix of init latent, last latent and noise.
+
+        Frame ``i`` keeps the init latent with weight
+        ``max(0, 1 - i ** latent_power / latent_scale)``. When ``params.last_frame`` is set,
+        it is encoded to a latent and weighted by the same formula (using
+        ``latent_power_last`` / ``latent_scale_last``) in reversed frame order. Whatever
+        weight is left over goes to ``p.rng.next()``.
+
+        Args:
+            p: img2img processing object; ``init_latent``, ``rng``, ``resize_mode``,
+                ``width``, ``height`` and ``sd_model`` are read, ``init_latent`` is replaced.
+            params: AnimateDiff settings; ``last_frame`` may be an image or a string
+                that is decoded with ``decode_base64_to_image``.
+        """
+        # Get init_alpha
+        init_alpha = [
+            1 - pow(i, params.latent_power) / params.latent_scale
+            for i in range(params.video_length)
+        ]
+        logger.info(f"Randomizing init_latent according to {init_alpha}.")
+        # Indexing with None adds three trailing axes of size 1.
+        init_alpha = torch.tensor(init_alpha, dtype=torch.float32, device=device)[
+            :, None, None, None
+        ]
+        init_alpha[init_alpha < 0] = 0
+
+        if params.last_frame is not None:
+            last_frame = params.last_frame
+            # isinstance (not ``type(...) == str``) so that str subclasses are decoded too.
+            if isinstance(last_frame, str):
+                from modules.api.api import decode_base64_to_image
+                last_frame = decode_base64_to_image(last_frame)
+            # Get last_alpha
+            last_alpha = [
+                1 - pow(i, params.latent_power_last) / params.latent_scale_last
+                for i in range(params.video_length)
+            ]
+            last_alpha.reverse()
+            logger.info(f"Randomizing last_latent according to {last_alpha}.")
+            last_alpha = torch.tensor(last_alpha, dtype=torch.float32, device=device)[
+                :, None, None, None
+            ]
+            last_alpha[last_alpha < 0] = 0
+
+            # Normalize alpha: where the two weights sum to more than 1, scale both down
+            # so they sum to exactly 1.
+            sum_alpha = init_alpha + last_alpha
+            mask_alpha = sum_alpha > 1
+            scaling_factor = 1 / sum_alpha[mask_alpha]
+            init_alpha[mask_alpha] *= scaling_factor
+            last_alpha[mask_alpha] *= scaling_factor
+            # Pin the end points: first frame is pure init latent, last frame pure last latent.
+            init_alpha[0] = 1
+            init_alpha[-1] = 0
+            last_alpha[0] = 0
+            last_alpha[-1] = 1
+
+            # Calculate last_latent
+            if p.resize_mode != 3:
+                last_frame = images.resize_image(
+                    p.resize_mode, last_frame, p.width, p.height
+                )
+            # Scale to [0, 1], move the channel axis first and add a leading axis.
+            last_frame = np.array(last_frame).astype(np.float32) / 255.0
+            last_frame = np.moveaxis(last_frame, 2, 0)[None, ...]
+            last_frame = torch.from_numpy(last_frame).to(device).to(dtype_vae)
+            last_latent = images_tensor_to_samples(
+                last_frame,
+                approximation_indexes.get(shared.opts.sd_vae_encode_method),
+                p.sd_model,
+            )
+            torch_gc()
+            if p.resize_mode == 3:
+                # resize_mode 3 resizes the latent instead of the image.
+                opt_f = 8
+                last_latent = torch.nn.functional.interpolate(
+                    last_latent,
+                    size=(p.height // opt_f, p.width // opt_f),
+                    mode="bilinear",
+                )
+            # Modify init_latent
+            p.init_latent = (
+                p.init_latent * init_alpha
+                + last_latent * last_alpha
+                + p.rng.next() * (1 - init_alpha - last_alpha)
+            )
+        else:
+            p.init_latent = p.init_latent * init_alpha + p.rng.next() * (1 - init_alpha)
diff --git a/sd-webui-animatediff/scripts/animatediff_lcm.py b/sd-webui-animatediff/scripts/animatediff_lcm.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce3ffe460511d194145e6e4adf8dfa4e0af763f
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_lcm.py
@@ -0,0 +1,137 @@
+
+# TODO: remove this file when LCM is merged to A1111
+import torch
+
+from k_diffusion import utils, sampling
+from k_diffusion.external import DiscreteEpsDDPMDenoiser
+from k_diffusion.sampling import default_noise_sampler, trange
+
+from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion
+from scripts.animatediff_logger import logger_animatediff as logger
+
+
+# Denoiser wrapper for LCM sampling: a DiscreteEpsDDPMDenoiser built on a reduced table of
+# 50 cumulative-alpha values taken from a 1000-step schedule.
+class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
+ def __init__(self, model):
+ timesteps = 1000
+ beta_start = 0.00085
+ beta_end = 0.012
+
+ # Betas are linearly spaced in sqrt-space and then squared.
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+ original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
+ # Distance (in full-schedule timesteps) between two retained entries: 1000 // 50 = 20.
+ self.skip_steps = timesteps // original_timesteps
+
+
+ # Keep every skip_steps-th entry, counting back from the last timestep.
+ alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
+ for x in range(original_timesteps):
+ alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+
+ super().__init__(model, alphas_cumprod_valid, quantize=None)
+
+
+ # Returns the sigma schedule with a trailing zero appended: the whole table reversed when
+ # n is None, otherwise n values linearly spaced in timestep between sigma_max and sigma_min
+ # (with sgm=True, n + 1 points are generated and the last one is dropped).
+ def get_sigmas(self, n=None, sgm=False):
+ if n is None:
+ return sampling.append_zero(self.sigmas.flip(0))
+
+ start = self.sigma_to_t(self.sigma_max)
+ end = self.sigma_to_t(self.sigma_min)
+
+ if sgm:
+ t = torch.linspace(start, end, n + 1, device=shared.sd_model.device)[:-1]
+ else:
+ t = torch.linspace(start, end, n, device=shared.sd_model.device)
+
+ return sampling.append_zero(self.t_to_sigma(t))
+
+
+ # Maps sigma to a full-schedule timestep: index of the nearest table entry (in log-sigma)
+ # times skip_steps, plus skip_steps - 1. The `quantize` argument is not used.
+ def sigma_to_t(self, sigma, quantize=None):
+ log_sigma = sigma.log()
+ dists = log_sigma - self.log_sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+
+
+ # Inverse mapping: converts a full-schedule timestep to a (clamped) table position and
+ # lets the parent class turn that position into a sigma.
+ def t_to_sigma(self, timestep):
+ t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+ return super().t_to_sigma(t)
+
+
+ # Noise prediction is delegated to the wrapped model's apply_model.
+ def get_eps(self, *args, **kwargs):
+ return self.inner_model.apply_model(*args, **kwargs)
+
+
+ # Blends the model output with the input using timestep-dependent c_skip / c_out
+ # weights (sigma_data = 0.5, timestep multiplied by 10.0).
+ def get_scaled_out(self, sigma, output, input):
+ sigma_data = 0.5
+ scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
+
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+ return c_out * output + c_skip * input
+
+
+ # Predicts eps on the scaled input, forms input + eps * c_out, and passes that through
+ # get_scaled_out.
+ def forward(self, input, sigma, **kwargs):
+ c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
+ eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+ return self.get_scaled_out(sigma, input + eps * c_out, input)
+
+
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+    """LCM sampling loop.
+
+    For every consecutive pair of sigmas the model output replaces ``x``; unless the next
+    sigma is 0, noise scaled by the next sigma is added in place. ``callback``, when given,
+    receives a dict with ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` before ``x``
+    is replaced. Returns the final ``x``.
+    """
+    if extra_args is None:
+        extra_args = {}
+    if noise_sampler is None:
+        noise_sampler = default_noise_sampler(x)
+    batch_ones = x.new_ones([x.shape[0]])
+
+    for step in trange(len(sigmas) - 1, disable=disable):
+        sigma_now = sigmas[step]
+        sigma_next = sigmas[step + 1]
+        denoised = model(x, sigma_now * batch_ones, **extra_args)
+
+        if callback is not None:
+            callback({'x': x, 'i': step, 'sigma': sigma_now, 'sigma_hat': sigma_now, 'denoised': denoised})
+
+        x = denoised
+        if sigma_next > 0:
+            x += sigma_next * noise_sampler(sigma_now, sigma_next)
+    return x
+
+
+class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
+    """CFGDenoiser whose inner model is an LCMCompVisDenoiser around shared.sd_model."""
+
+    @property
+    def inner_model(self):
+        # Built on first access and cached in self.model_wrap.
+        if self.model_wrap is not None:
+            return self.model_wrap
+        self.model_wrap = LCMCompVisDenoiser(shared.sd_model)
+        return self.model_wrap
+
+
+class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
+    """KDiffusionSampler that swaps in the LCM CFG denoiser after base initialisation."""
+
+    def __init__(self, funcname, sd_model, options=None):
+        super().__init__(funcname, sd_model, options)
+        lcm_cfg_denoiser = CFGDenoiserLCM(self)
+        self.model_wrap_cfg = lcm_cfg_denoiser
+        self.model_wrap = lcm_cfg_denoiser.inner_model
+
+
+class AnimateDiffLCM:
+    """Registers the LCM sampler in the web UI sampler list, at most once."""
+
+    # Set to True after the sampler has been added.
+    lcm_ui_injected = False
+
+    @staticmethod
+    def hack_kdiff_ui():
+        """Append the 'LCM' sampler to sd_samplers.all_samplers and refresh the sampler tables.
+
+        Returns early when the "animatediff_disable_lcm" option is truthy or when the
+        sampler was already injected.
+        """
+        if shared.opts.data.get("animatediff_disable_lcm", False):
+            return
+
+        if AnimateDiffLCM.lcm_ui_injected:
+            logger.info(f"LCM UI already injected.")
+            return
+
+        logger.info(f"Injecting LCM to UI.")
+        from modules import sd_samplers, sd_samplers_common
+
+        lcm_entries = [('LCM', sample_lcm, ['k_lcm'], {})]
+        new_sampler_data = []
+        for label, funcname, aliases, options in lcm_entries:
+            # funcname is bound as a default argument so each constructor keeps its own value.
+            constructor = lambda model, funcname=funcname: LCMSampler(funcname, model)
+            new_sampler_data.append(sd_samplers_common.SamplerData(label, constructor, aliases, options))
+
+        sd_samplers.all_samplers.extend(new_sampler_data)
+        sd_samplers.all_samplers_map = {entry.name: entry for entry in sd_samplers.all_samplers}
+        sd_samplers.set_samplers()
+        AnimateDiffLCM.lcm_ui_injected = True
diff --git a/sd-webui-animatediff/scripts/animatediff_logger.py b/sd-webui-animatediff/scripts/animatediff_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..043b809b912f9c15898eeb40ce97c2a929311b8f
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_logger.py
@@ -0,0 +1,41 @@
+import copy
+import logging
+import sys
+
+from modules import shared
+
+
+class ColoredFormatter(logging.Formatter):
+    """Formatter that wraps the level name in an ANSI colour sequence."""
+
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        # Work on a shallow copy so the caller's record keeps its plain level name.
+        reset = self.COLORS["RESET"]
+        tinted = copy.copy(record)
+        plain_level = tinted.levelname
+        colour = self.COLORS.get(plain_level, reset)
+        tinted.levelname = f"{colour}{plain_level}{reset}"
+        return super().format(tinted)
+
+
+# Create a new logger
+logger_animatediff = logging.getLogger("AnimateDiff")
+logger_animatediff.propagate = False
+
+# Add handler if we don't have one.
+if not logger_animatediff.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+    )
+    logger_animatediff.addHandler(handler)
+
+# Configure logger
+# The option may be absent, None, or not a valid level name. In any of those cases fall
+# back to INFO: calling .upper() on None or passing a non-level value to setLevel() would
+# raise while this module is being imported.
+loglevel_string = getattr(shared.cmd_opts, "animatediff_loglevel", None) or "INFO"
+loglevel = getattr(logging, str(loglevel_string).upper(), None)
+if not isinstance(loglevel, int):
+    loglevel = logging.INFO
+logger_animatediff.setLevel(loglevel)
diff --git a/sd-webui-animatediff/scripts/animatediff_lora.py b/sd-webui-animatediff/scripts/animatediff_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..184613a776ee9ffcc9b985906967f968a8676697
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_lora.py
@@ -0,0 +1,84 @@
+import os
+import re
+import sys
+
+from modules import sd_models, shared
+from modules.paths import extensions_builtin_dir
+
+from scripts.animatediff_logger import logger_animatediff as logger
+
+sys.path.append(f"{extensions_builtin_dir}/Lora")
+
+class AnimateDiffLora:
+    """Temporarily replaces ``networks.load_network`` so motion LoRA files can be loaded.
+
+    ``original_load_network`` is class-level state: it holds the loader that was
+    in place before ``hack`` ran and is ``None`` whenever no patch is installed.
+    """
+
+    original_load_network = None
+
+    def __init__(self, v2: bool):
+        # hack() and restore() both return immediately unless v2 is truthy.
+        self.v2 = v2
+
+    def hack(self):
+        """Install the motion-LoRA-aware loader (no-op if not v2 or already installed)."""
+        if not self.v2:
+            return
+
+        if AnimateDiffLora.original_load_network is not None:
+            logger.info("AnimateDiff LoRA already hacked")
+            return
+
+        logger.info("Hacking LoRA module to support motion LoRA")
+        import network
+        import networks
+        AnimateDiffLora.original_load_network = networks.load_network
+        original_load_network = AnimateDiffLora.original_load_network
+
+        def mm_load_network(name, network_on_disk):
+            """Load a motion LoRA file, or defer to the original loader for anything else."""
+
+            def convert_mm_name_to_compvis(key):
+                # Split "<module path>_lora.<part>" into the module key and "lora_<part>".
+                sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
+                sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
+                return sd_module_key, 'lora_' + network_part
+
+            net = network.Network(name, network_on_disk)
+            net.mtime = os.path.getmtime(network_on_disk.filename)
+
+            sd = sd_models.read_state_dict(network_on_disk.filename)
+
+            # A file counts as motion LoRA when its first key mentions
+            # 'motion_modules'. next(iter(...), '') reads only that key and
+            # does not raise IndexError on an empty state dict.
+            first_key = next(iter(sd.keys()), '')
+            if 'motion_modules' not in first_key:
+                del sd
+                return original_load_network(name, network_on_disk)
+
+            logger.info(f"Loading motion LoRA {name} from {network_on_disk.filename}")
+            matched_networks = {}
+
+            # Group the weights by the sd module they belong to.
+            for key_network, weight in sd.items():
+                key, network_part = convert_mm_name_to_compvis(key_network)
+                sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
+                assert sd_module is not None, f"Failed to find sd module for key {key}."
+
+                if key not in matched_networks:
+                    matched_networks[key] = network.NetworkWeights(
+                        network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
+
+                matched_networks[key].w[network_part] = weight
+
+            for key, weights in matched_networks.items():
+                net_module = networks.module_types[0].create_module(net, weights)
+                assert net_module is not None, "Failed to create motion module LoRA"
+                net.modules[key] = net_module
+
+            return net
+
+        networks.load_network = mm_load_network
+
+
+    def restore(self):
+        """Put the original ``networks.load_network`` back (no-op if not v2 or not patched)."""
+        if not self.v2:
+            return
+
+        if AnimateDiffLora.original_load_network is None:
+            logger.info("AnimateDiff LoRA already restored")
+            return
+
+        logger.info("Restoring hacked LoRA")
+        import networks
+        networks.load_network = AnimateDiffLora.original_load_network
+        AnimateDiffLora.original_load_network = None
diff --git a/sd-webui-animatediff/scripts/animatediff_mm.py b/sd-webui-animatediff/scripts/animatediff_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..537972dc00a45298bd228c652c652d356468072c
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_mm.py
@@ -0,0 +1,204 @@
+import gc
+import os
+
+import torch
+from einops import rearrange
+from modules import hashes, shared, sd_models, devices
+from modules.devices import cpu, device, torch_gc
+
+from motion_module import MotionWrapper, MotionModuleType
+from scripts.animatediff_logger import logger_animatediff as logger
+
+
+class AnimateDiffMM:
+ mm_injected = False
+
+    def __init__(self):
+        """Start with no motion module loaded and nothing saved for restoration."""
+        self.script_dir = None  # assigned through set_script_dir()
+        self.mm: MotionWrapper = None  # assigned in _load()
+        self.gn32_original_forward = None  # GroupNorm32.forward saved by inject()
+        self.prev_alpha_cumprod = None  # alphas_cumprod saved by _set_ddim_alpha()
+
+
+    def set_script_dir(self, script_dir):
+        """Store the directory that get_model_dir() uses to build its default model path."""
+        self.script_dir = script_dir
+
+
+    def get_model_dir(self):
+        """Return the configured ``animatediff_model_path`` option, or ``<script_dir>/model``.
+
+        The default is used both when the option is absent and when it is falsy.
+        """
+        default_dir = os.path.join(self.script_dir, "model")
+        configured_dir = shared.opts.data.get("animatediff_model_path", default_dir)
+        return configured_dir if configured_dir else default_dir
+
+
+ def _load(self, model_name):
+ """Load the motion module file `model_name` from the model directory and prepare it for use.
+
+ Raises RuntimeError when the file is not present in the model directory.
+ """
+ model_path = os.path.join(self.get_model_dir(), model_name)
+ if not os.path.isfile(model_path):
+ raise RuntimeError("Please download models manually.")
+ # Build a new wrapper only when nothing is loaded yet or a different file is requested.
+ if self.mm is None or self.mm.mm_name != model_name:
+ logger.info(f"Loading motion module {model_name} from {model_path}")
+ model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
+ mm_state_dict = sd_models.read_state_dict(model_path)
+ model_type = MotionModuleType.get_mm_type(mm_state_dict)
+ logger.info(f"Guessed {model_name} architecture: {model_type}")
+ self.mm = MotionWrapper(model_name, model_hash, model_type)
+ missed_keys = self.mm.load_state_dict(mm_state_dict)
+ # NOTE(review): this warning is logged after every state-dict load, whether or not
+ # any key is actually missing; Logger.warn is also a deprecated alias of Logger.warning.
+ logger.warn(f"Missing keys {missed_keys}")
+ self.mm.to(device).eval()
+ # Half precision unless the --no-half style option (shared.cmd_opts.no_half) is set.
+ if not shared.cmd_opts.no_half:
+ self.mm.half()
+ # When modules.devices exposes a truthy `fp8`, cast Conv2d/Linear layers to float8_e4m3fn.
+ if getattr(devices, "fp8", False):
+ for module in self.mm.modules():
+ if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
+ module.to(torch.float8_e4m3fn)
+
+
+ def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
+ """Load `model_name` and splice its motion layers into `sd_model`'s UNet.
+
+ If a module is already injected, restore() is called first. After the layers are
+ placed, the model's alphas_cumprod is replaced (_set_ddim_alpha) and the motion
+ layers are registered in the network layer mapping (_set_layer_mapping).
+ Raises AssertionError when the SD model and the motion module disagree on SDXL.
+ """
+ if AnimateDiffMM.mm_injected:
+ logger.info("Motion module already injected. Trying to restore.")
+ self.restore(sd_model)
+
+ unet = sd_model.model.diffusion_model
+ self._load(model_name)
+ inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
+ sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
+ assert sd_model.is_sdxl == self.mm.is_xl, f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}."
+
+ if self.mm.is_v2:
+ logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
+ # Inserted before the last child of the middle block; restore() pops index -2.
+ unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
+ elif self.mm.enable_gn_hack():
+ logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
+ if self.mm.is_hotshot:
+ from sgm.modules.diffusionmodules.util import GroupNorm32
+ else:
+ from ldm.modules.diffusionmodules.util import GroupNorm32
+ # Keep the original forward so restore() can put it back.
+ self.gn32_original_forward = GroupNorm32.forward
+ gn32_original_forward = self.gn32_original_forward
+
+ # Replacement forward: regroup "(b f) c h w" into "b c f h w" around the original
+ # forward. b=2 is hard-coded — presumably the cond/uncond halves; TODO confirm.
+ def groupnorm32_mm_forward(self, x):
+ x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
+ x = gn32_original_forward(self, x)
+ x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
+ return x
+
+ GroupNorm32.forward = groupnorm32_mm_forward
+
+ logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
+ # Motion layers are appended to these input blocks; the SDXL path stops after the
+ # first six entries (unet indices 1, 2, 4, 5, 7, 8).
+ for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
+ if inject_sdxl and mm_idx >= 6:
+ break
+ mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2
+ mm_inject = getattr(self.mm.down_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
+ unet.input_blocks[unet_idx].append(mm_inject)
+
+ logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
+ # Output blocks are taken in groups of three; the SDXL path only uses blocks 0-8.
+ for unet_idx in range(12):
+ if inject_sdxl and unet_idx >= 9:
+ break
+ mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3
+ mm_inject = getattr(self.mm.up_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
+ # Third block of each group (except the final output block): insert before its
+ # last child instead of appending.
+ if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
+ unet.output_blocks[unet_idx].insert(-1, mm_inject)
+ else:
+ unet.output_blocks[unet_idx].append(mm_inject)
+
+ self._set_ddim_alpha(sd_model)
+ self._set_layer_mapping(sd_model)
+ AnimateDiffMM.mm_injected = True
+ logger.info(f"Injection finished.")
+
+
+ def restore(self, sd_model):
+ """Undo inject(): remove the motion layers from the UNet and restore saved state.
+
+ Does nothing (besides logging) when no module is currently injected. The pop
+ positions below mirror the append/insert positions used by inject().
+ """
+ if not AnimateDiffMM.mm_injected:
+ logger.info("Motion module already removed.")
+ return
+
+ inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
+ sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
+ self._restore_ddim_alpha(sd_model)
+ unet = sd_model.model.diffusion_model
+
+ logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
+ # Motion layers were appended, so they are the last child of each input block.
+ # On the SDXL path the loop stops before unet indices 10 and 11.
+ for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]:
+ if inject_sdxl and unet_idx >= 9:
+ break
+ unet.input_blocks[unet_idx].pop(-1)
+
+ logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
+ for unet_idx in range(12):
+ if inject_sdxl and unet_idx >= 9:
+ break
+ # Where inject() used insert(-1, ...), the motion layer is second to last.
+ if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
+ unet.output_blocks[unet_idx].pop(-2)
+ else:
+ unet.output_blocks[unet_idx].pop(-1)
+
+ if self.mm.is_v2:
+ logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
+ unet.middle_block.pop(-2)
+ elif self.mm.enable_gn_hack():
+ logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
+ if self.mm.is_hotshot:
+ from sgm.modules.diffusionmodules.util import GroupNorm32
+ else:
+ from ldm.modules.diffusionmodules.util import GroupNorm32
+ # Put back the forward function saved by inject().
+ GroupNorm32.forward = self.gn32_original_forward
+ self.gn32_original_forward = None
+
+ AnimateDiffMM.mm_injected = False
+ logger.info(f"Removal finished.")
+ # With sd_model.lowvram set, also move the motion module off the device.
+ if sd_model.lowvram:
+ self.unload()
+
+
+    def _set_ddim_alpha(self, sd_model):
+        """Replace ``sd_model.alphas_cumprod`` with a schedule built for the motion module.
+
+        The previous tensor is kept in ``self.prev_alpha_cumprod`` so that
+        ``_restore_ddim_alpha`` can put it back.
+        """
+        logger.info(f"Setting DDIM alpha.")
+        beta_start = 0.00085
+        tensor_kwargs = dict(dtype=torch.float32, device=device)
+        if self.mm.is_adxl:
+            # 1000 steps, linear in sqrt(beta) and squared afterwards, ending at 0.020
+            beta_end = 0.020
+            betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, **tensor_kwargs) ** 2
+        else:
+            # linear in beta, ending at 0.012
+            beta_end = 0.012
+            step_count = 1000 if sd_model.is_sdxl else sd_model.num_timesteps
+            betas = torch.linspace(beta_start, beta_end, step_count, **tensor_kwargs)
+        new_alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
+        self.prev_alpha_cumprod = sd_model.alphas_cumprod
+        sd_model.alphas_cumprod = new_alphas_cumprod
+
+
+    def _set_layer_mapping(self, sd_model):
+        """Register each motion-module submodule under its name in ``sd_model.network_layer_mapping``.
+
+        Skipped entirely when ``sd_model`` has no ``network_layer_mapping`` attribute.
+        Each registered module also gets a ``network_layer_name`` attribute.
+        """
+        if not hasattr(sd_model, 'network_layer_mapping'):
+            return
+        for layer_name, layer in self.mm.named_modules():
+            sd_model.network_layer_mapping[layer_name] = layer
+            layer.network_layer_name = layer_name
+
+
+    def _restore_ddim_alpha(self, sd_model):
+        """Give ``sd_model`` back the alphas_cumprod saved by ``_set_ddim_alpha`` and clear the saved copy."""
+        logger.info(f"Restoring DDIM alpha.")
+        sd_model.alphas_cumprod, self.prev_alpha_cumprod = self.prev_alpha_cumprod, None
+
+
+ def unload(self):
+ """Move the loaded motion module (if any) to the CPU and run garbage collection."""
+ logger.info("Moving motion module to CPU")
+ if self.mm is not None:
+ self.mm.to(cpu)
+ # torch_gc comes from modules.devices; gc.collect is the standard Python collector.
+ torch_gc()
+ gc.collect()
+
+
+    def remove(self):
+        """Drop this object's reference to the motion module, then run garbage collection."""
+        logger.info("Removing motion module from any memory")
+        # Delete the attribute first, then recreate it as None so later
+        # ``self.mm is None`` checks keep working.
+        del self.mm
+        self.mm = None
+        torch_gc()
+        gc.collect()
+
+
+mm_animatediff = AnimateDiffMM()
diff --git a/sd-webui-animatediff/scripts/animatediff_output.py b/sd-webui-animatediff/scripts/animatediff_output.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2db85e6a6cf1776efeb681201cf7c3453d81bb8
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_output.py
@@ -0,0 +1,369 @@
+import base64
+import datetime
+from pathlib import Path
+import traceback
+
+import imageio.v3 as imageio
+import numpy as np
+from PIL import Image, PngImagePlugin
+import PIL.features
+import piexif
+from modules import images, shared
+from modules.processing import Processed, StableDiffusionProcessing
+
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_ui import AnimateDiffProcess
+
+
+
+class AnimateDiffOutput:
+ def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
+ """Save every generated video in the requested formats and replace `res.images` with the results.
+
+ Files go to `<p.outpath_samples>/AnimateDiff/<YYYY-MM-DD>/`. `res.images` is left
+ untouched when no video file was produced.
+ """
+ video_paths = []
+ first_frames = []
+ # True when any frame of the current call stack lives in a file whose name contains "xyz_grid".
+ from_xyz = any("xyz_grid" in frame.filename for frame in traceback.extract_stack())
+ logger.info(f"Saving output formats: {', '.join(params.format)}")
+ date = datetime.datetime.now().strftime('%Y-%m-%d')
+ output_dir = Path(f"{p.outpath_samples}/AnimateDiff/{date}")
+ output_dir.mkdir(parents=True, exist_ok=True)
+ # Stride between the first frames of consecutive videos in res.images:
+ # the larger of video_length and batch_size.
+ step = params.video_length if params.video_length > params.batch_size else params.batch_size
+ for i in range(res.index_of_first_image, len(res.images), step):
+ # frame interpolation replaces video_list with interpolated frames
+ # so make a copy instead of a slice (reference), to avoid modifying res
+ frame_list = [image.copy() for image in res.images[i : i + params.video_length]]
+ if from_xyz:
+ first_frames.append(res.images[i].copy())
+
+ # File name: <5-digit sequence>-<seed>[-<request_id>]
+ seq = images.get_next_sequence_number(output_dir, "")
+ filename_suffix = f"-{params.request_id}" if params.request_id else ""
+ filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}{filename_suffix}"
+
+ video_path_prefix = output_dir / filename
+
+ frame_list = self._add_reverse(params, frame_list)
+ frame_list = self._interp(p, params, frame_list, filename)
+ video_paths += self._save(params, frame_list, video_path_prefix, res, i)
+
+ if len(video_paths) == 0:
+ return
+
+ # Non-API callers get file paths. API callers get base64-encoded videos, plus the
+ # frames of the last processed video when 'Frame' is among the formats.
+ res.images = video_paths if not p.is_api else (self._encode_video_to_b64(video_paths) + (frame_list if 'Frame' in params.format else []))
+
+ # replace results with first frame of each video so xyz grid draws correctly
+ if from_xyz:
+ res.images = first_frames
+
+    def _add_reverse(self, params: AnimateDiffProcess, frame_list: list):
+        """Append the reversed frames (minus both end frames) when closed loop 'A' applies.
+
+        The reversed tail is only added when ``params.video_length`` does not
+        exceed ``params.batch_size`` and ``params.closed_loop`` is 'A'; otherwise
+        ``frame_list`` is returned as is.
+        """
+        wants_reverse = params.video_length <= params.batch_size and params.closed_loop in ['A']
+        if not wants_reverse:
+            return frame_list
+        # Reverse, then drop the first and last entries (where present) so the
+        # end frames are not repeated.
+        reversed_inner = frame_list[::-1][1:]
+        if reversed_inner:
+            del reversed_inner[-1]
+        return frame_list + reversed_inner
+
+
+    def _interp(
+        self,
+        p: StableDiffusionProcessing,
+        params: AnimateDiffProcess,
+        frame_list: list,
+        filename: str
+    ):
+        """Optionally run FILM frame interpolation (through Deforum) over ``frame_list``.
+
+        Returns ``frame_list`` unchanged when ``params.interp`` is not 'FILM' or
+        when the Deforum modules cannot be imported. Otherwise the frames are
+        written to a temporary folder, interpolated, and the resulting PNGs are
+        loaded back and returned as a new list of PIL images. When 'PNG' is among
+        the output formats, ``params.force_save_to_custom`` is set to True.
+        """
+        if params.interp not in ['FILM']:
+            return frame_list
+
+        try:
+            from deforum_helpers.frame_interpolation import (
+                calculate_frames_to_add, check_and_download_film_model)
+            from film_interpolation.film_inference import run_film_interp_infer
+        except ImportError:
+            logger.error("Deforum not found. Please install: https://github.com/deforum-art/deforum-for-automatic1111-webui.git")
+            return frame_list
+
+        import glob
+        import os
+        import shutil
+
+        import modules.paths as ph
+
+        # load film model
+        deforum_models_path = ph.models_path + '/Deforum'
+        film_model_folder = os.path.join(deforum_models_path,'film_interpolation')
+        film_model_name = 'film_net_fp16.pt'
+        film_model_path = os.path.join(film_model_folder, film_model_name)
+        check_and_download_film_model(film_model_name, film_model_folder)
+
+        film_in_between_frames_count = calculate_frames_to_add(len(frame_list), params.interp_x)
+
+        tmp_folder = f"{p.outpath_samples}/AnimateDiff/tmp"
+        input_folder = f"{tmp_folder}/input"
+        # deforum saves output frames to tmp/{filename}
+        save_folder = f"{tmp_folder}/{filename}"
+        try:
+            # save original frames to tmp folder for deforum input
+            os.makedirs(input_folder, exist_ok=True)
+            for tmp_seq, frame in enumerate(frame_list):
+                imageio.imwrite(f"{input_folder}/{tmp_seq:05}.png", frame)
+
+            os.makedirs(save_folder, exist_ok=True)
+
+            run_film_interp_infer(
+                model_path = film_model_path,
+                input_folder = input_folder,
+                save_folder = save_folder,
+                inter_frames = film_in_between_frames_count)
+
+            # load deforum output frames; they replace the original frame list
+            interp_frame_paths = sorted(glob.glob(os.path.join(save_folder, '*.png')))
+            interpolated_frames = []
+            for frame_path in interp_frame_paths:
+                with Image.open(frame_path) as img:
+                    # force the pixel data into memory before the file is closed
+                    img.load()
+                    interpolated_frames.append(img)
+        finally:
+            # Always remove the tmp folder, even when interpolation failed: the
+            # folder name is shared between runs, so frames left behind by a
+            # failed run would be picked up as input by the next one.
+            try:
+                shutil.rmtree(tmp_folder)
+            except OSError as e:
+                logger.warning(f"Failed to remove temporary folder {tmp_folder}: {e}")
+
+        # if saving PNG, enforce saving to custom folder
+        if "PNG" in params.format:
+            params.force_save_to_custom = True
+
+        return interpolated_frames
+
+
+ def _save(
+ self,
+ params: AnimateDiffProcess,
+ frame_list: list,
+ video_path_prefix: Path,
+ res: Processed,
+ index: int,
+ ):
+ """Write `frame_list` in every format listed in `params.format`.
+
+ All outputs share `video_path_prefix` and differ only by extension (PNG frames go
+ into a folder of that name). Returns the list of GIF/MP4/WEBP/WEBM paths written;
+ PNG frames and the TXT file are not included in it. When the
+ `animatediff_s3_enable` option is set, each returned path is also uploaded.
+ """
+ video_paths = []
+ video_array = [np.array(v) for v in frame_list]
+ infotext = res.infotexts[index]
+ s3_enable =shared.opts.data.get("animatediff_s3_enable", False)
+ use_infotext = shared.opts.enable_pnginfo and infotext is not None
+ # PNG frames are only written here when saving to the custom folder is enabled or forced.
+ # NOTE(review): this branch embeds infotext without consulting use_infotext — confirm intended.
+ if "PNG" in params.format and (shared.opts.data.get("animatediff_save_to_custom", False) or getattr(params, "force_save_to_custom", False)):
+ video_path_prefix.mkdir(exist_ok=True, parents=True)
+ for i, frame in enumerate(frame_list):
+ png_filename = video_path_prefix/f"{i:05}.png"
+ png_info = PngImagePlugin.PngInfo()
+ png_info.add_text('parameters', infotext)
+ imageio.imwrite(png_filename, frame, pnginfo=png_info)
+
+ if "GIF" in params.format:
+ video_path_gif = str(video_path_prefix) + ".gif"
+ video_paths.append(video_path_gif)
+ # Two GIF writers: pyav with a palettegen/paletteuse filter graph when palette
+ # optimization is enabled, otherwise the pillow plugin.
+ if shared.opts.data.get("animatediff_optimize_gif_palette", False):
+ try:
+ import av
+ except ImportError:
+ from launch import run_pip
+ run_pip(
+ "install imageio[pyav]",
+ "sd-webui-animatediff GIF palette optimization requirement: imageio[pyav]",
+ )
+ imageio.imwrite(
+ video_path_gif, video_array, plugin='pyav', fps=params.fps,
+ codec='gif', out_pixel_format='pal8',
+ filter_graph=(
+ {
+ "split": ("split", ""),
+ "palgen": ("palettegen", ""),
+ "paluse": ("paletteuse", ""),
+ "scale": ("scale", f"{frame_list[0].width}:{frame_list[0].height}")
+ },
+ [
+ ("video_in", "scale", 0, 0),
+ ("scale", "split", 0, 0),
+ ("split", "palgen", 1, 0),
+ ("split", "paluse", 0, 0),
+ ("palgen", "paluse", 0, 1),
+ ("paluse", "video_out", 0, 0),
+ ]
+ )
+ )
+ # imageio[pyav].imwrite doesn't support comment parameter
+ if use_infotext:
+ try:
+ import exiftool
+ except ImportError:
+ from launch import run_pip
+ run_pip(
+ "install PyExifTool",
+ "sd-webui-animatediff GIF palette optimization requirement: PyExifTool",
+ )
+ import exiftool
+ finally:
+ try:
+ exif_tool = exiftool.ExifTool()
+ with exif_tool:
+ # newlines in the infotext are written as a literal backslash-n
+ escaped_infotext = infotext.replace('\n', r'\n')
+ exif_tool.execute("-overwrite_original", f"-Comment={escaped_infotext}", video_path_gif)
+ except FileNotFoundError:
+ logger.warn(
+ "exiftool not found, required for infotext with optimized GIF palette, try: apt install libimage-exiftool-perl or https://exiftool.org/"
+ )
+ else:
+ # duration is the per-frame display time in milliseconds
+ imageio.imwrite(
+ video_path_gif,
+ video_array,
+ plugin='pillow',
+ duration=(1000 / params.fps),
+ loop=params.loop_number,
+ comment=(infotext if use_infotext else "")
+ )
+ if shared.opts.data.get("animatediff_optimize_gif_gifsicle", False):
+ self._optimize_gif(video_path_gif)
+
+ if "MP4" in params.format:
+ video_path_mp4 = str(video_path_prefix) + ".mp4"
+ video_paths.append(video_path_mp4)
+ try:
+ import av
+ except ImportError:
+ from launch import run_pip
+ run_pip(
+ "install imageio[pyav]",
+ "sd-webui-animatediff MP4 save requirement: imageio[pyav]",
+ )
+ import av
+ # Encoder options come from the animatediff_mp4_* settings; preset and tune are
+ # only passed when non-empty.
+ options = {
+ "crf": str(shared.opts.data.get("animatediff_mp4_crf", 23))
+ }
+ preset = shared.opts.data.get("animatediff_mp4_preset", "")
+ if preset != "": options["preset"] = preset
+ tune = shared.opts.data.get("animatediff_mp4_tune", "")
+ if tune != "": options["tune"] = tune
+ output = av.open(video_path_mp4, "w")
+ logger.info(f"Saving {video_path_mp4}")
+ if use_infotext:
+ output.metadata["Comment"] = infotext
+ stream = output.add_stream('libx264', params.fps, options=options)
+ stream.width = frame_list[0].width
+ stream.height = frame_list[0].height
+ for img in video_array:
+ frame = av.VideoFrame.from_ndarray(img)
+ packet = stream.encode(frame)
+ output.mux(packet)
+ # encode(None) is called once more after the last frame and its result muxed
+ # — presumably to flush the encoder; confirm against PyAV docs.
+ packet = stream.encode(None)
+ output.mux(packet)
+ output.close()
+
+ if "TXT" in params.format and res.images[index].info is not None:
+ video_path_txt = str(video_path_prefix) + ".txt"
+ with open(video_path_txt, "w", encoding="utf8") as file:
+ file.write(f"{infotext}\n")
+
+ if "WEBP" in params.format:
+ if PIL.features.check('webp_anim'):
+ video_path_webp = str(video_path_prefix) + ".webp"
+ video_paths.append(video_path_webp)
+ exif_bytes = b''
+ # infotext is stored in the EXIF UserComment field
+ if use_infotext:
+ exif_bytes = piexif.dump({
+ "Exif":{
+ piexif.ExifIFD.UserComment:piexif.helper.UserComment.dump(infotext, encoding="unicode")
+ }})
+ lossless = shared.opts.data.get("animatediff_webp_lossless", False)
+ quality = shared.opts.data.get("animatediff_webp_quality", 80)
+ logger.info(f"Saving {video_path_webp} with lossless={lossless} and quality={quality}")
+ imageio.imwrite(video_path_webp, video_array, plugin='pillow',
+ duration=int(1 / params.fps * 1000), loop=params.loop_number,
+ lossless=lossless, quality=quality, exif=exif_bytes
+ )
+ # see additional Pillow WebP options at https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
+ else:
+ logger.warn("WebP animation in Pillow requires system WebP library v0.5.0 or later")
+ if "WEBM" in params.format:
+ video_path_webm = str(video_path_prefix) + ".webm"
+ video_paths.append(video_path_webm)
+ logger.info(f"Saving {video_path_webm}")
+ with imageio.imopen(video_path_webm, "w", plugin="pyav") as file:
+ if use_infotext:
+ file.container_metadata["Title"] = infotext
+ file.container_metadata["Comment"] = infotext
+ file.write(video_array, codec="vp9", fps=params.fps)
+
+ if s3_enable:
+ for video_path in video_paths: self._save_to_s3_stroge(video_path)
+ return video_paths
+
+
+    def _optimize_gif(self, video_path: str):
+        """Run ``pygifsicle.optimize`` on the GIF at ``video_path``.
+
+        pygifsicle is pip-installed through the webui launcher when it cannot be
+        imported. A ``FileNotFoundError`` from the optimizer is logged as a
+        warning that the gifsicle binary is missing, not raised.
+        """
+        try:
+            import pygifsicle
+        except ImportError:
+            from launch import run_pip
+
+            run_pip(
+                "install pygifsicle",
+                "sd-webui-animatediff GIF optimization requirement: pygifsicle",
+            )
+            import pygifsicle
+
+        # Deliberately not in a ``finally`` clause: if the install or import above
+        # failed, running this would raise NameError and hide the real error.
+        try:
+            pygifsicle.optimize(video_path)
+        except FileNotFoundError:
+            logger.warning("gifsicle not found, required for optimized GIFs, try: apt install gifsicle")
+
+
+    def _encode_video_to_b64(self, paths):
+        """Return the base64 encoding (as UTF-8 text) of each file in ``paths``, in order."""
+
+        def read_as_b64(file_path):
+            with open(file_path, "rb") as stream:
+                raw_bytes = stream.read()
+            return base64.b64encode(raw_bytes).decode("utf-8")
+
+        return [read_as_b64(v_path) for v_path in paths]
+
+    def _install_requirement_if_absent(self,lib):
+        """pip-install ``lib`` through the webui launcher unless ``launch.is_installed`` reports it present."""
+        import launch
+        if launch.is_installed(lib):
+            return
+        launch.run_pip(f"install {lib}", f"animatediff requirement: {lib}")
+
+    def _exist_bucket(self,s3_client,bucketname):
+        """Return True when ``s3_client.head_bucket`` succeeds for ``bucketname``.
+
+        A ``ClientError`` whose error code is '404' means the bucket is absent and
+        yields False; any other ``ClientError`` is re-raised.
+        """
+        # ClientError is not imported at module level, so bring it into scope
+        # here; without this the except clause itself fails with NameError.
+        from botocore.exceptions import ClientError
+
+        try:
+            s3_client.head_bucket(Bucket=bucketname)
+        except ClientError as e:
+            if e.response['Error']['Code'] == '404':
+                return False
+            raise
+        return True
+
+    def _save_to_s3_stroge(self ,file_path):
+        """
+        Upload a local file to the S3-compatible object storage configured in the webui options.
+
+        Host, port, credentials and bucket are read from the ``animatediff_s3_*``
+        options. The bucket is created when ``_exist_bucket`` reports it missing,
+        and the object key is ``<YYYY-MM-DD>/<file name>``.
+
+        :type file_path: string
+        :param file_path: path of the local file to upload
+        :return: URL of the uploaded object, or None if ``file_path`` does not exist
+        """
+        import os
+
+        # Check first: a missing file needs neither a boto3 install nor a client.
+        if not os.path.exists(file_path): return
+
+        self._install_requirement_if_absent('boto3')
+        import boto3
+
+        host = shared.opts.data.get("animatediff_s3_host", '127.0.0.1')
+        port = shared.opts.data.get("animatediff_s3_port", '9001')
+        access_key = shared.opts.data.get("animatediff_s3_access_key", '')
+        secret_key = shared.opts.data.get("animatediff_s3_secret_key", '')
+        bucket = shared.opts.data.get("animatediff_s3_storge_bucket", '')
+        client = boto3.client(
+            service_name='s3',
+            aws_access_key_id = access_key,
+            aws_secret_access_key = secret_key,
+            endpoint_url=f'http://{host}:{port}',
+        )
+
+        date = datetime.datetime.now().strftime('%Y-%m-%d')
+        if not self._exist_bucket(client,bucket):
+            client.create_bucket(Bucket=bucket)
+
+        filename = os.path.split(file_path)[1]
+        targetpath = f"{date}/{filename}"
+        client.upload_file(file_path, bucket, targetpath)
+        logger.info(f"{file_path} saved to s3 in bucket: {bucket}")
+        return f"http://{host}:{port}/{bucket}/{targetpath}"
+
diff --git a/sd-webui-animatediff/scripts/animatediff_prompt.py b/sd-webui-animatediff/scripts/animatediff_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..143753441784ff5dad4c1c8ece024e8eb06d056d
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_prompt.py
@@ -0,0 +1,145 @@
+import re
+import torch
+
+from modules.processing import StableDiffusionProcessing, Processed
+
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_infotext import write_params_txt
+from scripts.animatediff_ui import AnimateDiffProcess
+
+class AnimateDiffPromptSchedule:
+
+    def __init__(self):
+        """Start without a prompt map; parse_prompt() fills both attributes when prompt travel is used."""
+        self.original_prompt = None  # the user's prompt text before expansion
+        self.prompt_map = None  # keyframe number -> full prompt for that keyframe
+
+
+    def save_infotext_img(self, p: StableDiffusionProcessing):
+        """When prompt travel is active, set ``p.prompts`` to ``p.batch_size`` copies of the original prompt."""
+        if self.prompt_map is None:
+            return
+        p.prompts = [self.original_prompt] * p.batch_size
+
+
+ def save_infotext_txt(self, res: Processed):
+ """Put the original prompt back into `res.info` and every entry of `res.infotexts`.
+
+ Only the text before the first '\nNegative prompt: ' marker is replaced; entries
+ without that marker are left unchanged.
+ """
+ if self.prompt_map is not None:
+ parts = res.info.split('\nNegative prompt: ', 1)
+ if len(parts) > 1:
+ res.info = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
+ for i in range(len(res.infotexts)):
+ parts = res.infotexts[i].split('\nNegative prompt: ', 1)
+ if len(parts) > 1:
+ res.infotexts[i] = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
+ # write_params_txt is imported from scripts.animatediff_infotext
+ write_params_txt(res.info)
+
+
+    def parse_prompt(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
+        """Parse prompt-travel syntax in ``p.prompt`` and expand it to one prompt per frame.
+
+        The prompt is read line by line: lines before the first line starting with
+        ``<digits>:`` are head prompts, lines of the form ``<frame>: <prompt>`` are
+        keyframes, and everything from the first non-keyframe line after that is
+        a tail prompt. When at least one keyframe exists, ``self.prompt_map`` maps
+        each keyframe to ``"<head>, <prompt>, <tail>"``, ``self.original_prompt``
+        keeps the original text, and ``p.prompt`` becomes a list of
+        ``p.batch_size`` prompts repeated ``p.n_iter`` times.
+
+        Raises:
+            AssertionError: if a keyframe number is not below ``params.video_length``
+                or the expanded list does not have ``p.batch_size`` entries.
+        """
+        if type(p.prompt) is not str:
+            logger.warning("prompt is not str, cannot support prompt map")
+            return
+
+        lines = p.prompt.strip().split('\n')
+        data = {
+            'head_prompts': [],
+            'mapp_prompts': {},
+            'tail_prompts': []
+        }
+
+        # The three checks below are sequential on purpose: the line that
+        # switches the mode is then handled by the new mode right away.
+        mode = 'head'
+        for line in lines:
+            if mode == 'head':
+                if re.match(r'^\d+:', line):
+                    mode = 'mapp'
+                else:
+                    data['head_prompts'].append(line)
+
+            if mode == 'mapp':
+                match = re.match(r'^(\d+): (.+)$', line)
+                if match:
+                    frame, prompt = match.groups()
+                    assert int(frame) < params.video_length, \
+                        f"invalid prompt travel frame number: {int(frame)} >= number of frames ({params.video_length})"
+                    data['mapp_prompts'][int(frame)] = prompt
+                else:
+                    mode = 'tail'
+
+            if mode == 'tail':
+                data['tail_prompts'].append(line)
+
+        if data['mapp_prompts']:
+            logger.info("You are using prompt travel.")
+            self.prompt_map = {}
+            prompt_list = []
+            last_frame = 0
+            current_prompt = ''
+            head_text = ', '.join(data['head_prompts'])
+            tail_text = ', '.join(data['tail_prompts'])
+            # Walk keyframes in ascending frame order, whatever order they were
+            # written in. Unsorted keyframes produced a list of the wrong length
+            # (tripping the assertion below) and a prompt_map whose key order
+            # single_cond() cannot use.
+            for frame, prompt in sorted(data['mapp_prompts'].items()):
+                # frames up to (not including) this keyframe keep the previous prompt
+                prompt_list += [current_prompt for _ in range(last_frame, frame)]
+                last_frame = frame
+                current_prompt = f"{head_text}, {prompt}, {tail_text}"
+                self.prompt_map[frame] = current_prompt
+            prompt_list += [current_prompt for _ in range(last_frame, p.batch_size)]
+            assert len(prompt_list) == p.batch_size, f"prompt_list length {len(prompt_list)} != batch_size {p.batch_size}"
+            self.original_prompt = p.prompt
+            p.prompt = prompt_list * p.n_iter
+
+
+    def single_cond(self, center_frame, video_length: int, cond: torch.Tensor, closed_loop = False):
+        """Return the conditioning for ``center_frame``, blended between its neighbouring keyframes.
+
+        The previous keyframe is the last one at or before ``center_frame`` and
+        the next keyframe is the first one after it. When there is none on a side,
+        the defaults are the first/last keyframe, swapped when ``closed_loop``
+        is set. ``cond`` is either a tensor indexed by frame or a dict of such
+        tensors; the result has the same kind.
+        """
+        keyframes = list(self.prompt_map.keys())
+        if closed_loop:
+            key_prev, key_next = keyframes[-1], keyframes[0]
+        else:
+            key_prev, key_next = keyframes[0], keyframes[-1]
+
+        for keyframe in keyframes:
+            if keyframe > center_frame:
+                key_next = keyframe
+                break
+            key_prev = keyframe
+
+        # Negative distances wrap around the clip length.
+        dist_prev = center_frame - key_prev
+        dist_next = key_next - center_frame
+        if dist_prev < 0:
+            dist_prev += video_length
+        if dist_next < 0:
+            dist_next += video_length
+
+        cond_is_tensor = isinstance(cond, torch.Tensor)
+
+        # Nothing to blend: both neighbours are the same keyframe.
+        if key_prev == key_next or dist_prev + dist_next == 0:
+            if cond_is_tensor:
+                return cond[key_prev]
+            return {k: v[key_prev] for k, v in cond.items()}
+
+        rate = dist_prev / (dist_prev + dist_next)
+        if cond_is_tensor:
+            return AnimateDiffPromptSchedule.slerp(cond[key_prev], cond[key_next], rate)
+        # otherwise cond is a dict of tensors
+        return {
+            k: AnimateDiffPromptSchedule.slerp(v[key_prev], v[key_next], rate)
+            for k, v in cond.items()
+        }
+
+
+    def multi_cond(self, cond: torch.Tensor, closed_loop = False):
+        """Rebuild ``cond`` frame by frame using ``single_cond``.
+
+        Returns ``cond`` itself when no prompt map is set. Otherwise the
+        per-frame results are stacked and converted to the dtype and device of
+        the input (per key when ``cond`` is not a tensor).
+        """
+        if self.prompt_map is None:
+            return cond
+
+        cond_is_tensor = isinstance(cond, torch.Tensor)
+        per_frame = []
+        for frame_idx in range(cond.shape[0]):
+            per_frame.append(self.single_cond(frame_idx, cond.shape[0], cond, closed_loop))
+
+        if cond_is_tensor:
+            return torch.stack(per_frame).to(cond.dtype).to(cond.device)
+
+        # Non-tensor cond: regroup the per-frame dicts into one list per key.
+        grouped = {k: [frame_cond[k] for frame_cond in per_frame] for k in cond.keys()}
+        return {k: torch.stack(v).to(cond[k].dtype).to(cond[k].device) for k, v in grouped.items()}
+
+
+    @staticmethod
+    def slerp(
+        v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
+    ) -> torch.Tensor:
+        """Spherical linear interpolation from ``v0`` (t=0) to ``v1`` (t=1).
+
+        Falls back to plain linear interpolation when the absolute dot product
+        of the normalised inputs exceeds ``DOT_THRESHOLD``.
+        """
+        cos_omega = ((v0 / v0.norm()) * (v1 / v1.norm())).sum()
+        if cos_omega.abs() > DOT_THRESHOLD:
+            return (1.0 - t) * v0 + t * v1
+        omega = cos_omega.acos()
+        weight_start = ((1.0 - t) * omega).sin()
+        weight_end = (t * omega).sin()
+        return (weight_start * v0 + weight_end * v1) / omega.sin()
diff --git a/sd-webui-animatediff/scripts/animatediff_ui.py b/sd-webui-animatediff/scripts/animatediff_ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd5c773546643b637953c389be0f0101ecc02695
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_ui.py
@@ -0,0 +1,374 @@
+import os
+
+import cv2
+import gradio as gr
+
+from modules import shared
+from modules.processing import StableDiffusionProcessing
+
+from scripts.animatediff_mm import mm_animatediff as motion_module
+from scripts.animatediff_i2ibatch import animatediff_i2ibatch
+from scripts.animatediff_lcm import AnimateDiffLCM
+from scripts.animatediff_logger import logger_animatediff as logger
+from scripts.animatediff_xyz import xyz_attrs
+
+# Output formats offered as choices of the "Save format" checkbox group.
+supported_save_formats = ["GIF", "MP4", "WEBP", "WEBM", "PNG", "TXT"]
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+    """Small button with single emoji as text, fits inside gradio forms"""
+
+    def __init__(self, **kwargs):
+        # Always created with the "tool" variant; all other options are passed through.
+        super().__init__(variant="tool", **kwargs)
+
+
+    def get_block_name(self):
+        # Reports the block name "button".
+        # NOTE(review): presumably so Gradio renders this subclass like a plain gr.Button - confirm.
+        return "button"
+
+
+class AnimateDiffProcess:
+
+ def __init__(
+ self,
+ model="mm_sd_v15_v2.ckpt",
+ enable=False,
+ video_length=0,
+ fps=8,
+ loop_number=0,
+ closed_loop='R-P',
+ batch_size=16,
+ stride=1,
+ overlap=-1,
+ format=shared.opts.data.get("animatediff_default_save_formats", ["GIF", "PNG"]),
+ interp='Off',
+ interp_x=10,
+ video_source=None,
+ video_path='',
+ latent_power=1,
+ latent_scale=32,
+ last_frame=None,
+ latent_power_last=1,
+ latent_scale_last=32,
+ request_id = '',
+ ):
+ self.model = model
+ self.enable = enable
+ self.video_length = video_length
+ self.fps = fps
+ self.loop_number = loop_number
+ self.closed_loop = closed_loop
+ self.batch_size = batch_size
+ self.stride = stride
+ self.overlap = overlap
+ self.format = format
+ self.interp = interp
+ self.interp_x = interp_x
+ self.video_source = video_source
+ self.video_path = video_path
+ self.latent_power = latent_power
+ self.latent_scale = latent_scale
+ self.last_frame = last_frame
+ self.latent_power_last = latent_power_last
+ self.latent_scale_last = latent_scale_last
+ self.request_id = request_id
+
+
+ def get_list(self, is_img2img: bool):
+ list_var = list(vars(self).values())[:-1]
+ if is_img2img:
+ animatediff_i2ibatch.hack()
+ else:
+ list_var = list_var[:-5]
+ return list_var
+
+
+ def get_dict(self, is_img2img: bool):
+ infotext = {
+ "enable": self.enable,
+ "model": self.model,
+ "video_length": self.video_length,
+ "fps": self.fps,
+ "loop_number": self.loop_number,
+ "closed_loop": self.closed_loop,
+ "batch_size": self.batch_size,
+ "stride": self.stride,
+ "overlap": self.overlap,
+ "interp": self.interp,
+ "interp_x": self.interp_x,
+ }
+ if self.request_id:
+ infotext['request_id'] = self.request_id
+ if motion_module.mm is not None and motion_module.mm.mm_hash is not None:
+ infotext['mm_hash'] = motion_module.mm.mm_hash[:8]
+ if is_img2img:
+ infotext.update({
+ "latent_power": self.latent_power,
+ "latent_scale": self.latent_scale,
+ "latent_power_last": self.latent_power_last,
+ "latent_scale_last": self.latent_scale_last,
+ })
+ infotext_str = ', '.join(f"{k}: {v}" for k, v in infotext.items())
+ return infotext_str
+
+
+ def get_param_names(self, is_img2img: bool):
+ remove = ["format", "request_id", "video_source", "video_path", "last_frame"]
+ if not is_img2img:
+ remove.extend(["latent_power", "latent_power_last", "latent_scale", "latent_scale_last"])
+
+ return [
+ field
+ for field in self.__dict__
+ if field not in remove and not callable(getattr(self, field)) and not field.startswith("__")
+ ]
+
+
+ def _check(self):
+ assert (
+ self.video_length >= 0 and self.fps > 0
+ ), "Video length and FPS should be positive."
+ assert not set(["GIF", "MP4", "PNG", "WEBP", "WEBM"]).isdisjoint(
+ self.format
+ ), "At least one saving format should be selected."
+
+
+ def set_p(self, p: StableDiffusionProcessing):
+ self._check()
+ if self.video_length < self.batch_size:
+ p.batch_size = self.batch_size
+ else:
+ p.batch_size = self.video_length
+ if self.video_length == 0:
+ self.video_length = p.batch_size
+ self.video_default = True
+ else:
+ self.video_default = False
+ if self.overlap == -1:
+ self.overlap = self.batch_size // 4
+ if "PNG" not in self.format or shared.opts.data.get("animatediff_save_to_custom", False):
+ p.do_not_save_samples = True
+
+ def apply_xyz(self):
+ for k, v in xyz_attrs.items():
+ setattr(self, k, v)
+
+
+class AnimateDiffUiGroup:
+ txt2img_submit_button = None
+ img2img_submit_button = None
+
+ def __init__(self):
+ self.params = AnimateDiffProcess()
+
+
+ def render(self, is_img2img: bool, model_dir: str, infotext_fields, paste_field_names):
+ if not os.path.isdir(model_dir):
+ os.mkdir(model_dir)
+ elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-"
+ model_list = [f for f in os.listdir(model_dir) if f != ".gitkeep"]
+ with gr.Accordion("AnimateDiff", open=False):
+ gr.Markdown(value="Please click [this link](https://github.com/continue-revolution/sd-webui-animatediff#webui-parameters) to read the documentation of each parameter.")
+ with gr.Row():
+
+ def refresh_models(*inputs):
+ new_model_list = [
+ f for f in os.listdir(model_dir) if f != ".gitkeep"
+ ]
+ dd = inputs[0]
+ if dd in new_model_list:
+ selected = dd
+ elif len(new_model_list) > 0:
+ selected = new_model_list[0]
+ else:
+ selected = None
+ return gr.Dropdown.update(choices=new_model_list, value=selected)
+
+ with gr.Row():
+ self.params.model = gr.Dropdown(
+ choices=model_list,
+ value=(self.params.model if self.params.model in model_list else None),
+ label="Motion module",
+ type="value",
+ elem_id=f"{elemid_prefix}motion-module",
+ )
+ refresh_model = ToolButton(value="\U0001f504")
+ refresh_model.click(refresh_models, self.params.model, self.params.model)
+
+ self.params.format = gr.CheckboxGroup(
+ choices=supported_save_formats,
+ label="Save format",
+ type="value",
+ elem_id=f"{elemid_prefix}save-format",
+ value=self.params.format,
+ )
+ with gr.Row():
+ self.params.enable = gr.Checkbox(
+ value=self.params.enable, label="Enable AnimateDiff",
+ elem_id=f"{elemid_prefix}enable"
+ )
+ self.params.video_length = gr.Number(
+ minimum=0,
+ value=self.params.video_length,
+ label="Number of frames",
+ precision=0,
+ elem_id=f"{elemid_prefix}video-length",
+ )
+ self.params.fps = gr.Number(
+ value=self.params.fps, label="FPS", precision=0,
+ elem_id=f"{elemid_prefix}fps"
+ )
+ self.params.loop_number = gr.Number(
+ minimum=0,
+ value=self.params.loop_number,
+ label="Display loop number",
+ precision=0,
+ elem_id=f"{elemid_prefix}loop-number",
+ )
+ with gr.Row():
+ self.params.closed_loop = gr.Radio(
+ choices=["N", "R-P", "R+P", "A"],
+ value=self.params.closed_loop,
+ label="Closed loop",
+ elem_id=f"{elemid_prefix}closed-loop",
+ )
+ self.params.batch_size = gr.Slider(
+ minimum=1,
+ maximum=32,
+ value=self.params.batch_size,
+ label="Context batch size",
+ step=1,
+ precision=0,
+ elem_id=f"{elemid_prefix}batch-size",
+ )
+ self.params.stride = gr.Number(
+ minimum=1,
+ value=self.params.stride,
+ label="Stride",
+ precision=0,
+ elem_id=f"{elemid_prefix}stride",
+ )
+ self.params.overlap = gr.Number(
+ minimum=-1,
+ value=self.params.overlap,
+ label="Overlap",
+ precision=0,
+ elem_id=f"{elemid_prefix}overlap",
+ )
+ with gr.Row():
+ self.params.interp = gr.Radio(
+ choices=["Off", "FILM"],
+ label="Frame Interpolation",
+ elem_id=f"{elemid_prefix}interp-choice",
+ value=self.params.interp
+ )
+ self.params.interp_x = gr.Number(
+ value=self.params.interp_x, label="Interp X", precision=0,
+ elem_id=f"{elemid_prefix}interp-x"
+ )
+ self.params.video_source = gr.Video(
+ value=self.params.video_source,
+ label="Video source",
+ )
+ def update_fps(video_source):
+ if video_source is not None and video_source != '':
+ cap = cv2.VideoCapture(video_source)
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
+ cap.release()
+ return fps
+ else:
+ return int(self.params.fps.value)
+ self.params.video_source.change(update_fps, inputs=self.params.video_source, outputs=self.params.fps)
+ def update_frames(video_source):
+ if video_source is not None and video_source != '':
+ cap = cv2.VideoCapture(video_source)
+ frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ cap.release()
+ return frames
+ else:
+ return int(self.params.video_length.value)
+ self.params.video_source.change(update_frames, inputs=self.params.video_source, outputs=self.params.video_length)
+ self.params.video_path = gr.Textbox(
+ value=self.params.video_path,
+ label="Video path",
+ elem_id=f"{elemid_prefix}video-path"
+ )
+ if is_img2img:
+ with gr.Row():
+ self.params.latent_power = gr.Slider(
+ minimum=0.1,
+ maximum=10,
+ value=self.params.latent_power,
+ step=0.1,
+ label="Latent power",
+ elem_id=f"{elemid_prefix}latent-power",
+ )
+ self.params.latent_scale = gr.Slider(
+ minimum=1,
+ maximum=128,
+ value=self.params.latent_scale,
+ label="Latent scale",
+ elem_id=f"{elemid_prefix}latent-scale"
+ )
+ self.params.latent_power_last = gr.Slider(
+ minimum=0.1,
+ maximum=10,
+ value=self.params.latent_power_last,
+ step=0.1,
+ label="Optional latent power for last frame",
+ elem_id=f"{elemid_prefix}latent-power-last",
+ )
+ self.params.latent_scale_last = gr.Slider(
+ minimum=1,
+ maximum=128,
+ value=self.params.latent_scale_last,
+ label="Optional latent scale for last frame",
+ elem_id=f"{elemid_prefix}latent-scale-last"
+ )
+ self.params.last_frame = gr.Image(
+ label="Optional last frame. Leave it blank if you do not need one.",
+ type="pil",
+ )
+ with gr.Row():
+ unload = gr.Button(value="Move motion module to CPU (default if lowvram)")
+ remove = gr.Button(value="Remove motion module from any memory")
+ unload.click(fn=motion_module.unload)
+ remove.click(fn=motion_module.remove)
+
+ # Set up controls to be copy-pasted using infotext
+ fields = self.params.get_param_names(is_img2img)
+ infotext_fields.extend((getattr(self.params, field), f"AnimateDiff {field}") for field in fields)
+ paste_field_names.extend(f"AnimateDiff {field}" for field in fields)
+
+ return self.register_unit(is_img2img)
+
+
+ def register_unit(self, is_img2img: bool):
+ unit = gr.State(value=AnimateDiffProcess)
+ (
+ AnimateDiffUiGroup.img2img_submit_button
+ if is_img2img
+ else AnimateDiffUiGroup.txt2img_submit_button
+ ).click(
+ fn=AnimateDiffProcess,
+ inputs=self.params.get_list(is_img2img),
+ outputs=unit,
+ queue=False,
+ )
+ return unit
+
+
+ @staticmethod
+ def on_after_component(component, **_kwargs):
+ elem_id = getattr(component, "elem_id", None)
+
+ if elem_id == "txt2img_generate":
+ AnimateDiffUiGroup.txt2img_submit_button = component
+ return
+
+ if elem_id == "img2img_generate":
+ AnimateDiffUiGroup.img2img_submit_button = component
+ return
+
+
+ @staticmethod
+ def on_before_ui():
+ AnimateDiffLCM.hack_kdiff_ui()
diff --git a/sd-webui-animatediff/scripts/animatediff_xyz.py b/sd-webui-animatediff/scripts/animatediff_xyz.py
new file mode 100644
index 0000000000000000000000000000000000000000..b924a398af6f79f59b0633dbfef9db71374fc80c
--- /dev/null
+++ b/sd-webui-animatediff/scripts/animatediff_xyz.py
@@ -0,0 +1,126 @@
+import sys
+from types import ModuleType
+from typing import Optional
+
+from modules import scripts
+
+from scripts.animatediff_logger import logger_animatediff as logger
+
+# Attribute name -> value overrides; filled by the callbacks that apply_state() creates.
+xyz_attrs: dict = {}
+
+def patch_xyz():
+ xyz_module = find_xyz_module()
+ if xyz_module is None:
+ logger.warning("XYZ module not found.")
+ return
+ MODULE = "[AnimateDiff]"
+ xyz_module.axis_options.extend([
+ xyz_module.AxisOption(
+ label=f"{MODULE} Enabled",
+ type=str_to_bool,
+ apply=apply_state("enable"),
+ choices=choices_bool),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Motion Module",
+ type=str,
+ apply=apply_state("model")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Video length",
+ type=int_or_float,
+ apply=apply_state("video_length")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} FPS",
+ type=int_or_float,
+ apply=apply_state("fps")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Use main seed",
+ type=str_to_bool,
+ apply=apply_state("use_main_seed"),
+ choices=choices_bool),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Closed loop",
+ type=str,
+ apply=apply_state("closed_loop"),
+ choices=lambda: ["N", "R-P", "R+P", "A"]),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Batch size",
+ type=int_or_float,
+ apply=apply_state("batch_size")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Stride",
+ type=int_or_float,
+ apply=apply_state("stride")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Overlap",
+ type=int_or_float,
+ apply=apply_state("overlap")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Interp",
+ type=str_to_bool,
+ apply=apply_state("interp"),
+ choices=choices_bool),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Interp X",
+ type=int_or_float,
+ apply=apply_state("interp_x")),
+ xyz_module.AxisOption(
+ label=f"{MODULE} Video path",
+ type=str,
+ apply=apply_state("video_path")),
+ xyz_module.AxisOptionImg2Img(
+ label=f"{MODULE} Latent power",
+ type=int_or_float,
+ apply=apply_state("latent_power")),
+ xyz_module.AxisOptionImg2Img(
+ label=f"{MODULE} Latent scale",
+ type=int_or_float,
+ apply=apply_state("latent_scale")),
+ xyz_module.AxisOptionImg2Img(
+ label=f"{MODULE} Latent power last",
+ type=int_or_float,
+ apply=apply_state("latent_power_last")),
+ xyz_module.AxisOptionImg2Img(
+ label=f"{MODULE} Latent scale last",
+ type=int_or_float,
+ apply=apply_state("latent_scale_last")),
+ ])
+
+
+def apply_state(k, key_map=None):
+ def callback(_p, v, _vs):
+ if key_map is not None:
+ v = key_map[v]
+ xyz_attrs[k] = v
+
+ return callback
+
+
+def str_to_bool(string):
+ string = str(string)
+ if string in ["None", ""]:
+ return None
+ elif string.lower() in ["true", "1"]:
+ return True
+ elif string.lower() in ["false", "0"]:
+ return False
+ else:
+ raise ValueError(f"Could not convert string to boolean: {string}")
+
+
+def int_or_float(string):
+ try:
+ return int(string)
+ except ValueError:
+ return float(string)
+
+
+def choices_bool():
+ return ["False", "True"]
+
+
+def find_xyz_module() -> Optional[ModuleType]:
+ for data in scripts.scripts_data:
+ if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py"} and hasattr(data, "module"):
+ return data.module
+
+ return None
diff --git a/sd-webui-controlnet/.github/ISSUE_TEMPLATE/bug_report.yml b/sd-webui-controlnet/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 0000000000000000000000000000000000000000..ce58f6775fbefa358a2bf50562901ebe4898818c
--- /dev/null
+++ b/sd-webui-controlnet/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,91 @@
+name: Bug Report
+description: Create a report
+title: "[Bug]: "
+labels: ["bug-report"]
+
+body:
+ - type: checkboxes
+ attributes:
+ label: Is there an existing issue for this?
+ description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
+ options:
+ - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui
+ required: true
+ - type: markdown
+ attributes:
+ value: |
+ *Please fill out this form with as much information as possible. Don't forget to fill in "What OS..." and "What browsers", and* **provide screenshots if possible**.
+ - type: textarea
+ id: what-did
+ attributes:
+ label: What happened?
+ description: Tell us what happened in a very clear and simple way
+ validations:
+ required: true
+ - type: textarea
+ id: steps
+ attributes:
+ label: Steps to reproduce the problem
+ description: Please provide us with precise step by step information on how to reproduce the bug
+ value: |
+ 1. Go to ....
+ 2. Press ....
+ 3. ...
+ validations:
+ required: true
+ - type: textarea
+ id: what-should
+ attributes:
+ label: What should have happened?
+ description: Tell what you think the normal behavior should be
+ validations:
+ required: true
+ - type: textarea
+ id: commits
+ attributes:
+ label: Commit where the problem happens
+ description: Which commit of the extension are you running on? Please include the commit of both the extension and the webui (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
+ value: |
+ webui:
+ controlnet:
+ validations:
+ required: true
+ - type: dropdown
+ id: browsers
+ attributes:
+ label: What browsers do you use to access the UI ?
+ multiple: true
+ options:
+ - Mozilla Firefox
+ - Google Chrome
+ - Brave
+ - Apple Safari
+ - Microsoft Edge
+ - type: textarea
+ id: cmdargs
+ attributes:
+ label: Command Line Arguments
+ description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise.
+ render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: extensions
+ attributes:
+ label: List of enabled extensions
+ description: Please provide a full list of enabled extensions or screenshots of your "Extensions" tab.
+ validations:
+ required: true
+ - type: textarea
+ id: logs
+ attributes:
+ label: Console logs
+ description: Please provide full cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
+ render: Shell
+ validations:
+ required: true
+ - type: textarea
+ id: misc
+ attributes:
+ label: Additional information
+ description: Please provide us with any relevant additional info or context.
diff --git a/sd-webui-controlnet/.github/ISSUE_TEMPLATE/config.yml b/sd-webui-controlnet/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0086358db1eb971c0cfa8739c27518bbc18a5ff4
--- /dev/null
+++ b/sd-webui-controlnet/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1 @@
+blank_issues_enabled: true
diff --git a/sd-webui-controlnet/.github/workflows/tests.yml b/sd-webui-controlnet/.github/workflows/tests.yml
new file mode 100644
index 0000000000000000000000000000000000000000..51f19e2552f6fafdd70a58e09fdae82e69bc16e2
--- /dev/null
+++ b/sd-webui-controlnet/.github/workflows/tests.yml
@@ -0,0 +1,114 @@
+name: Run basic features tests on CPU
+
+on:
+ - push
+ - pull_request
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v3
+ with:
+ repository: 'AUTOMATIC1111/stable-diffusion-webui'
+ path: 'stable-diffusion-webui'
+ ref: '4afaaf8a020c1df457bcf7250cb1c7f609699fa7'
+ - name: Checkout Code
+ uses: actions/checkout@v3
+ with:
+ repository: 'Mikubill/sd-webui-controlnet'
+ path: 'stable-diffusion-webui/extensions/sd-webui-controlnet'
+ - name: Set up Python 3.10
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.10.6
+ cache: pip
+ cache-dependency-path: |
+ **/requirements*txt
+ launch.py
+ - name: Install test dependencies
+ run: |
+ pip install wait-for-it
+ pip install -r requirements-test.txt
+ working-directory: stable-diffusion-webui
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: "1"
+ PIP_PROGRESS_BAR: "off"
+ - name: Setup environment
+ run: python launch.py --skip-torch-cuda-test --exit
+ working-directory: stable-diffusion-webui
+ env:
+ PIP_DISABLE_PIP_VERSION_CHECK: "1"
+ PIP_PROGRESS_BAR: "off"
+ TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
+ WEBUI_LAUNCH_LIVE_OUTPUT: "1"
+ PYTHONUNBUFFERED: "1"
+ - name: Cache ControlNet models
+ uses: actions/cache@v3
+ with:
+ path: stable-diffusion-webui/extensions/sd-webui-controlnet/models/
+ key: controlnet-models-v3
+ - name: Cache Preprocessor models
+ uses: actions/cache@v3
+ with:
+ path: stable-diffusion-webui/extensions/sd-webui-controlnet/annotator/downloads/
+ key: preprocessor-models-v1
+ - name: Download controlnet model for testing
+ run: |
+ declare -a urls=(
+ "https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth"
+ "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-full-face_sd15.safetensors"
+ "https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.safetensors"
+ "https://huggingface.co/TencentARC/T2I-Adapter/resolve/main/models/t2iadapter_canny_sd15v2.pth"
+ "https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_lora_rank128_v11p_sd15_canny_fp16.safetensors"
+ )
+
+ for url in "${urls[@]}"; do
+ filename="extensions/sd-webui-controlnet/models/${url##*/}" # Extracts the last part of the URL
+ if [ ! -f "$filename" ]; then
+ curl -Lo "$filename" "$url"
+ fi
+ done
+ working-directory: stable-diffusion-webui
+ - name: Start test server
+ run: >
+ python -m coverage run
+ --data-file=.coverage.server
+ launch.py
+ --skip-prepare-environment
+ --skip-torch-cuda-test
+ --test-server
+ --do-not-download-clip
+ --no-half
+ --disable-opt-split-attention
+ --use-cpu all
+ --api-server-stop
+ 2>&1 | tee output.txt &
+ working-directory: stable-diffusion-webui
+ - name: Run tests
+ run: |
+ wait-for-it --service 127.0.0.1:7860 -t 600
+ python -m pytest -vv --junitxml=test/results.xml --cov ./extensions/sd-webui-controlnet --cov-report=xml --verify-base-url ./extensions/sd-webui-controlnet/tests
+ working-directory: stable-diffusion-webui
+ - name: Kill test server
+ if: always()
+ run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
+ - name: Show coverage
+ run: |
+ python -m coverage combine .coverage*
+ python -m coverage report -i
+ python -m coverage html -i
+ working-directory: stable-diffusion-webui
+ - name: Upload main app output
+ uses: actions/upload-artifact@v3
+ if: always()
+ with:
+ name: output
+ path: stable-diffusion-webui/output.txt
+ - name: Upload coverage HTML
+ uses: actions/upload-artifact@v3
+ if: always()
+ with:
+ name: htmlcov
+ path: stable-diffusion-webui/htmlcov
diff --git a/sd-webui-controlnet/.gitignore b/sd-webui-controlnet/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..60d06e51ec71848d6700eac9c6f3db544ef3c1a0
--- /dev/null
+++ b/sd-webui-controlnet/.gitignore
@@ -0,0 +1,185 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea
+*.pt
+*.pth
+*.ckpt
+*.bin
+*.safetensors
+
+# Editor setting metadata
+.idea/
+.vscode/
+detected_maps/
+annotator/downloads/
+
+# test results and expectations
+web_tests/results/
+web_tests/expectations/
+tests/web_api/full_coverage/results/
+tests/web_api/full_coverage/expectations/
+
+*_diff.png
+
+# Presets
+presets/
+
+# Ignore existing dir of hand refiner if exists.
+annotator/hand_refiner_portable
\ No newline at end of file
diff --git a/sd-webui-controlnet/LICENSE b/sd-webui-controlnet/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f288702d2fa16d3cdf0035b15a9fcbc552cd88e7
--- /dev/null
+++ b/sd-webui-controlnet/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/sd-webui-controlnet/README.md b/sd-webui-controlnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..38460ecaea68e1b6912f1b88b3fd6163e47bf1f5
--- /dev/null
+++ b/sd-webui-controlnet/README.md
@@ -0,0 +1,243 @@
+# ControlNet for Stable Diffusion WebUI
+
+The WebUI extension for ControlNet and other injection-based SD controls.
+
+![image](https://github.com/Mikubill/sd-webui-controlnet/assets/20929282/51172d20-606b-4b9f-aba5-db2f2417cb0b)
+
+This extension is for AUTOMATIC1111's [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). It allows the Web UI to add [ControlNet](https://github.com/lllyasviel/ControlNet) to the original Stable Diffusion model to generate images. The addition is on-the-fly; merging is not required.
+
+# Installation
+
+1. Open "Extensions" tab.
+2. Open "Install from URL" tab in the tab.
+3. Enter `https://github.com/Mikubill/sd-webui-controlnet.git` to "URL for extension's git repository".
+4. Press "Install" button.
+5. Wait for 5 seconds, and you will see the message "Installed into stable-diffusion-webui\extensions\sd-webui-controlnet. Use Installed tab to restart".
+6. Go to "Installed" tab, click "Check for updates", and then click "Apply and restart UI". (The next time you can also use these buttons to update ControlNet.)
+7. Completely restart A1111 webui including your terminal. (If you do not know what is a "terminal", you can reboot your computer to achieve the same effect.)
+8. Download models (see below).
+9. After you put models in the correct folder, you may need to refresh to see the models. The refresh button is to the right of your "Model" dropdown.
+
+# Download Models
+
+Right now all the 14 models of ControlNet 1.1 are in the beta test.
+
+Download the models from ControlNet 1.1: https://huggingface.co/lllyasviel/ControlNet-v1-1/tree/main
+
+You need to download model files ending with ".pth" .
+
+Put models in your "stable-diffusion-webui\extensions\sd-webui-controlnet\models". You only need to download "pth" files.
+
+Do not right-click the filenames in HuggingFace website to download. Some users right-clicked those HuggingFace HTML websites and saved those HTML pages as PTH/YAML files. They are not downloading correct files. Instead, please click the small download arrow “↓” icon in HuggingFace to download.
+
+# Download Models for SDXL
+
+See instructions [here](https://github.com/Mikubill/sd-webui-controlnet/discussions/2039).
+
+# Features in ControlNet 1.1
+
+### Perfect Support for All ControlNet 1.0/1.1 and T2I Adapter Models.
+
+Now we have perfect support for all available models and preprocessors, including perfect support for T2I style adapter and ControlNet 1.1 Shuffle. (Make sure that your YAML file names and model file names are the same; see also the YAML files in "stable-diffusion-webui\extensions\sd-webui-controlnet\models".)
+
+### Perfect Support for A1111 High-Res. Fix
+
+Now if you turn on High-Res Fix in A1111, each controlnet will output two different control images: a small one and a large one. The small one is for your basic generating, and the big one is for your High-Res Fix generating. The two control images are computed by a smart algorithm called "super high-quality control image resampling". This is turned on by default, and you do not need to change any setting.
+
+### Perfect Support for All A1111 Img2Img or Inpaint Settings and All Mask Types
+
+Now ControlNet is extensively tested with A1111's different types of masks, including "Inpaint masked"/"Inpaint not masked", and "Whole picture"/"Only masked", and "Only masked padding"&"Mask blur". The resizing perfectly matches A1111's "Just resize"/"Crop and resize"/"Resize and fill". This means you can use ControlNet nearly everywhere in your A1111 UI without difficulty!
+
+### The New "Pixel-Perfect" Mode
+
+Now if you turn on pixel-perfect mode, you do not need to set preprocessor (annotator) resolutions manually. The ControlNet will automatically compute the best annotator resolution for you so that each pixel perfectly matches Stable Diffusion.
+
+### User-Friendly GUI and Preprocessor Preview
+
+We reorganized some previously confusing UI like "canvas width/height for new canvas" and it is in the 📝 button now. Now the preview GUI is controlled by the "allow preview" option and the trigger button 💥. The preview image size is better than before, and you do not need to scroll up and down - your a1111 GUI will not be messed up anymore!
+
+### Support for Almost All Upscaling Scripts
+
+Now ControlNet 1.1 can support almost all Upscaling/Tile methods. ControlNet 1.1 supports the script "Ultimate SD upscale" and almost all other tile-based extensions. Please do not confuse ["Ultimate SD upscale"](https://github.com/Coyote-A/ultimate-upscale-for-automatic1111) with "SD upscale" - they are different scripts. Note that the most recommended upscaling method is ["Tiled VAE/Diffusion"](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111) but we test as many methods/extensions as possible. Note that "SD upscale" is supported since 1.1.117, and if you use it, you need to leave all ControlNet images blank (We do not recommend "SD upscale" since it is somewhat buggy and cannot be maintained - use the "Ultimate SD upscale" instead).
+
+### More Control Modes (previously called Guess Mode)
+
+We have fixed many bugs in ControlNet 1.0's Guess Mode, and it is now called Control Mode.
+
+![image](https://user-images.githubusercontent.com/19834515/236641759-6c44ddf6-c7ad-4bda-92be-e90a52911d75.png)
+
+Now you can control which aspect is more important (your prompt or your ControlNet):
+
+* "Balanced": ControlNet on both sides of CFG scale, same as turning off "Guess Mode" in ControlNet 1.0
+
+* "My prompt is more important": ControlNet on both sides of CFG scale, with progressively reduced SD U-Net injections (layer_weight*=0.825**I, where 0<=I <13, and the 13 means ControlNet injected SD 13 times). In this way, you can make sure that your prompts are perfectly displayed in your generated images.
+
+* "ControlNet is more important": ControlNet only on the Conditional Side of CFG scale (the cond in A1111's batch-cond-uncond). This means the ControlNet will be X times stronger if your cfg-scale is X. For example, if your cfg-scale is 7, then ControlNet is 7 times stronger. Note that here the X times stronger is different from "Control Weights" since your weights are not modified. This "stronger" effect usually has fewer artifacts and gives ControlNet more room to guess what is missing from your prompts (and in the previous 1.0, it is called "Guess Mode").
+
+
+
+
Input (depth+canny+hed)
+
"Balanced"
+
"My prompt is more important"
+
"ControlNet is more important"
+
+
+
+
+
+
+
+
+
+### Reference-Only Control
+
+Now we have a `reference-only` preprocessor that does not require any control models. It can guide the diffusion directly using images as references.
+
+(Prompt "a dog running on grassland, best quality, ...")
+
+![image](samples/ref.png)
+
+This method is similar to inpaint-based reference but it does not make your image disordered.
+
+Many professional A1111 users know a trick to diffuse image with references by inpaint. For example, if you have a 512x512 image of a dog, and want to generate another 512x512 image with the same dog, some users will connect the 512x512 dog image and a 512x512 blank image into a 1024x512 image, send to inpaint, and mask out the blank 512x512 part to diffuse a dog with similar appearance. However, that method is usually not very satisfying since images are connected and many distortions will appear.
+
+This `reference-only` ControlNet can directly link the attention layers of your SD to any independent images, so that your SD will read arbitrary images for reference. You need at least ControlNet 1.1.153 to use it.
+
+To use, just select `reference-only` as preprocessor and put an image. Your SD will just use the image as reference.
+
+*Note that this method is as "non-opinioned" as possible. It only contains very basic connection codes, without any personal preferences, to connect the attention layers with your reference images. However, even if we tried best to not include any opinioned codes, we still need to write some subjective implementations to deal with weighting, cfg-scale, etc - tech report is on the way.*
+
+More examples [here](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236).
+
+# Technical Documents
+
+See also the documents of ControlNet 1.1:
+
+https://github.com/lllyasviel/ControlNet-v1-1-nightly#model-specification
+
+# Default Setting
+
+This is my setting. If you run into any problem, you can use this setting as a sanity check
+
+![image](https://user-images.githubusercontent.com/19834515/235620638-17937171-8ac1-45bc-a3cb-3aebf605b4ef.png)
+
+# Use Previous Models
+
+### Use ControlNet 1.0 Models
+
+https://huggingface.co/lllyasviel/ControlNet/tree/main/models
+
+You can still use all previous models from ControlNet 1.0. The previous "depth" is now called "depth_midas", the previous "normal" is called "normal_midas", and the previous "hed" is called "softedge_hed". Starting from 1.1, all line maps, edge maps, lineart maps, and boundary maps will have black background and white lines.
+
+### Use T2I-Adapter Models
+
+(From TencentARC/T2I-Adapter)
+
+To use T2I-Adapter models:
+
+1. Download files from https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models
+2. Put them in "stable-diffusion-webui\extensions\sd-webui-controlnet\models".
+3. Make sure that the file names of pth files and yaml files are consistent.
+
+*Note that "CoAdapter" is not implemented yet.*
+
+# Gallery
+
+The below results are from ControlNet 1.0.
+
+| Source | Input | Output |
+|:-------------------------:|:-------------------------:|:-------------------------:|
+| (no preprocessor) | | |
+| (no preprocessor) | | |
+| | | |
+| | | |
+| | | |
+| | | |
+
+The below examples are from T2I-Adapter.
+
+From `t2iadapter_color_sd14v1.pth` :
+
+| Source | Input | Output |
+|:-------------------------:|:-------------------------:|:-------------------------:|
+| | | |
+
+From `t2iadapter_style_sd14v1.pth` :
+
+| Source | Input | Output |
+|:-------------------------:|:-------------------------:|:-------------------------:|
+| | (clip, non-image) | |
+
+# Minimum Requirements
+
+* (Windows) (NVIDIA: Ampere) 4gb - with `--xformers` enabled, and `Low VRAM` mode ticked in the UI, goes up to 768x832
+
+# Multi-ControlNet
+
+This option allows multiple ControlNet inputs for a single generation. To enable this option, change `Multi ControlNet: Max models amount (requires restart)` in the settings. Note that you will need to restart the WebUI for changes to take effect.
+
+
+
+
Source A
+
Source B
+
Output
+
+
+
+
+
+
+
+
+# Control Weight/Start/End
+
+Weight is the weight of the controlnet "influence". It's analogous to prompt attention/emphasis. E.g. (myprompt: 1.2). Technically, it's the factor by which to multiply the ControlNet outputs before merging them with original SD Unet.
+
+Guidance Start/End is the percentage of total steps the controlnet applies (guidance strength = guidance end). It's analogous to prompt editing/shifting. E.g. \[myprompt::0.8\] (It applies from the beginning until 80% of total steps)
+
+# Batch Mode
+
+Put any unit into batch mode to activate batch mode for all units. Specify a batch directory for each unit, or use the new textbox in the img2img batch tab as a fallback. Although the textbox is located in the img2img batch tab, you can use it to generate images in the txt2img tab as well.
+
+Note that this feature is only available in the gradio user interface. Call the APIs as many times as you want for custom batch scheduling.
+
+# API and Script Access
+
+This extension can accept txt2img or img2img tasks via API or external extension call. Note that you may need to enable `Allow other scripts to control this extension` in settings for external calls.
+
+To use the API: start WebUI with argument `--api` and go to `http://webui-address/docs` for documents or check out the [examples](https://github.com/Mikubill/sd-webui-controlnet/blob/main/example/txt2img_example/api_txt2img.py).
+
+To use external call: check out the [Wiki](https://github.com/Mikubill/sd-webui-controlnet/wiki/API)
+
+# Command Line Arguments
+
+This extension adds these command line arguments to the webui:
+
+```
+ --controlnet-dir ADD a controlnet models directory
+ --controlnet-annotator-models-path SET the directory for annotator models
+ --no-half-controlnet load controlnet models in full precision
+ --controlnet-preprocessor-cache-size Cache size for controlnet preprocessor results
+ --controlnet-loglevel Log level for the controlnet extension
+ --controlnet-tracemalloc Enable malloc memory tracing
+```
+
+# MacOS Support
+
+Tested with pytorch nightly: https://github.com/Mikubill/sd-webui-controlnet/pull/143#issuecomment-1435058285
+
+To use this extension with mps and normal pytorch, currently you may need to start WebUI with `--no-half`.
+
+# Archive of Deprecated Versions
+
+The previous version (sd-webui-controlnet 1.0) is archived in
+
+https://github.com/lllyasviel/webui-controlnet-v1-archived
+
+Using this version is not a temporary stop of updates. You will stop all updates forever.
+
+Please consider this version if you work with professional studios that require 100% reproduction of all previous results, pixel by pixel.
+
+# Thanks
+
+This implementation is inspired by kohya-ss/sd-webui-additional-networks
diff --git a/sd-webui-controlnet/annotator/anime_face_segment/LICENSE b/sd-webui-controlnet/annotator/anime_face_segment/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9bad05450ca061904f97acebe04ff7183cfbdc1a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/anime_face_segment/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Miaomiao Li
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/sd-webui-controlnet/annotator/anime_face_segment/__init__.py b/sd-webui-controlnet/annotator/anime_face_segment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..466feca11c0e26be0af707257b100e819eaf5b7f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/anime_face_segment/__init__.py
@@ -0,0 +1,172 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+import fnmatch
+import cv2
+
+import sys
+
+import numpy as np
+from modules import devices
+from einops import rearrange
+from annotator.annotator_path import models_path
+
+import torchvision
+from torchvision.models import MobileNet_V2_Weights
+from torchvision import transforms
+
+# Three-component colour assigned to each segmentation class when the label map is rendered
+# (presumably RGB - TODO confirm against the consumer of the annotator output).
+COLOR_BACKGROUND = (255,255,0)
+COLOR_HAIR = (0,0,255)
+COLOR_EYE = (255,0,0)
+COLOR_MOUTH = (255,255,255)
+COLOR_FACE = (0,255,0)
+COLOR_SKIN = (0,255,255)
+COLOR_CLOTHES = (255,0,255)
+# Class-index -> colour lookup: AnimeFaceSegment.__call__ indexes this list with the argmax of the network output.
+# NOTE(review): the order here (background, hair, eye, mouth, face, skin, clothes) differs from the class order
+# listed in the comment next to UNet.NUM_SEG_CLASSES - confirm which one matches the checkpoint.
+PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES]
+
+class UNet(nn.Module):
+ def __init__(self):
+ super(UNet, self).__init__()
+ self.NUM_SEG_CLASSES = 7 # Background, hair, face, eye, mouth, skin, clothes
+
+ mobilenet_v2 = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
+ mob_blocks = mobilenet_v2.features
+
+ # Encoder
+ self.en_block0 = nn.Sequential( # in_ch=3 out_ch=16
+ mob_blocks[0],
+ mob_blocks[1]
+ )
+ self.en_block1 = nn.Sequential( # in_ch=16 out_ch=24
+ mob_blocks[2],
+ mob_blocks[3],
+ )
+ self.en_block2 = nn.Sequential( # in_ch=24 out_ch=32
+ mob_blocks[4],
+ mob_blocks[5],
+ mob_blocks[6],
+ )
+ self.en_block3 = nn.Sequential( # in_ch=32 out_ch=96
+ mob_blocks[7],
+ mob_blocks[8],
+ mob_blocks[9],
+ mob_blocks[10],
+ mob_blocks[11],
+ mob_blocks[12],
+ mob_blocks[13],
+ )
+ self.en_block4 = nn.Sequential( # in_ch=96 out_ch=160
+ mob_blocks[14],
+ mob_blocks[15],
+ mob_blocks[16],
+ )
+
+ # Decoder
+ self.de_block4 = nn.Sequential( # in_ch=160 out_ch=96
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(160, 96, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(96),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+ self.de_block3 = nn.Sequential( # in_ch=96x2 out_ch=32
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(96*2, 32, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(32),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+ self.de_block2 = nn.Sequential( # in_ch=32x2 out_ch=24
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(32*2, 24, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(24),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+ self.de_block1 = nn.Sequential( # in_ch=24x2 out_ch=16
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(24*2, 16, kernel_size=3, padding=1),
+ nn.InstanceNorm2d(16),
+ nn.LeakyReLU(0.1),
+ nn.Dropout(p=0.2)
+ )
+
+ self.de_block0 = nn.Sequential( # in_ch=16x2 out_ch=7
+ nn.UpsamplingNearest2d(scale_factor=2),
+ nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
+ nn.Softmax2d()
+ )
+
+ def forward(self, x):
+ e0 = self.en_block0(x)
+ e1 = self.en_block1(e0)
+ e2 = self.en_block2(e1)
+ e3 = self.en_block3(e2)
+ e4 = self.en_block4(e3)
+
+ d4 = self.de_block4(e4)
+ d4 = F.interpolate(d4, size=e3.size()[2:], mode='bilinear', align_corners=True)
+ c4 = torch.cat((d4,e3),1)
+
+ d3 = self.de_block3(c4)
+ d3 = F.interpolate(d3, size=e2.size()[2:], mode='bilinear', align_corners=True)
+ c3 = torch.cat((d3,e2),1)
+
+ d2 = self.de_block2(c3)
+ d2 = F.interpolate(d2, size=e1.size()[2:], mode='bilinear', align_corners=True)
+ c2 =torch.cat((d2,e1),1)
+
+ d1 = self.de_block1(c2)
+ d1 = F.interpolate(d1, size=e0.size()[2:], mode='bilinear', align_corners=True)
+ c1 = torch.cat((d1,e0),1)
+ y = self.de_block0(c1)
+
+ return y
+
+
+class AnimeFaceSegment:
+
+ model_dir = os.path.join(models_path, "anime_face_segment")
+
+ def __init__(self):
+ self.model = None
+ self.device = devices.get_device_for("controlnet")
+
+ def load_model(self):
+ remote_model_path = "https://huggingface.co/bdsqlsz/qinglong_controlnet-lllite/resolve/main/Annotators/UNet.pth"
+ modelpath = os.path.join(self.model_dir, "UNet.pth")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
+ net = UNet()
+ ckpt = torch.load(modelpath, map_location=self.device)
+ for key in list(ckpt.keys()):
+ if 'module.' in key:
+ ckpt[key.replace('module.', '')] = ckpt[key]
+ del ckpt[key]
+ net.load_state_dict(ckpt)
+ net.eval()
+ self.model = net.to(self.device)
+
+ def unload_model(self):
+ if self.model is not None:
+ self.model.cpu()
+
+ def __call__(self, input_image):
+
+ if self.model is None:
+ self.load_model()
+ self.model.to(self.device)
+ transform = transforms.Compose([
+ transforms.Resize(512,interpolation=transforms.InterpolationMode.BICUBIC),
+ transforms.ToTensor(),])
+ img = Image.fromarray(input_image)
+ with torch.no_grad():
+ img = transform(img).unsqueeze(dim=0).to(self.device)
+ seg = self.model(img).squeeze(dim=0)
+ seg = seg.cpu().detach().numpy()
+ img = rearrange(seg,'h w c -> w c h')
+ img = [[PALETTE[np.argmax(val)] for val in buf]for buf in img]
+ return np.array(img).astype(np.uint8)
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/annotator_path.py b/sd-webui-controlnet/annotator/annotator_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba168e19cf0eb7f7dae6ac3d54c5977945e7386a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/annotator_path.py
@@ -0,0 +1,22 @@
+import os
+from modules import shared
+
+models_path = shared.opts.data.get('control_net_modules_path', None)
+if not models_path:
+ models_path = getattr(shared.cmd_opts, 'controlnet_annotator_models_path', None)
+if not models_path:
+ models_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'downloads')
+
+if not os.path.isabs(models_path):
+ models_path = os.path.join(shared.data_path, models_path)
+
+clip_vision_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision')
+# clip vision is always inside controlnet "extensions\sd-webui-controlnet"
+# and any problem can be solved by removing controlnet and reinstall
+
+models_path = os.path.realpath(models_path)
+os.makedirs(models_path, exist_ok=True)
+print(f'ControlNet preprocessor location: {models_path}')
+# Make sure that the default location is inside controlnet "extensions\sd-webui-controlnet"
+# so that any problem can be solved by removing controlnet and reinstall
+# if users do not change configs on their own (otherwise users will know what is wrong)
diff --git a/sd-webui-controlnet/annotator/binary/__init__.py b/sd-webui-controlnet/annotator/binary/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d13ad692ffc109ad95789334bb5524d52794acc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/binary/__init__.py
@@ -0,0 +1,14 @@
+import cv2
+
+
+def apply_binary(img, bin_threshold):
+ img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+
+ if bin_threshold == 0 or bin_threshold == 255:
+ # Otsu's threshold
+ otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+ print("Otsu threshold:", otsu_threshold)
+ else:
+ _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
+
+ return cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
diff --git a/sd-webui-controlnet/annotator/canny/__init__.py b/sd-webui-controlnet/annotator/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace985839d3fc18dd4947f6c38e9f5d5a2625aca
--- /dev/null
+++ b/sd-webui-controlnet/annotator/canny/__init__.py
@@ -0,0 +1,5 @@
+import cv2
+
+
+def apply_canny(img, low_threshold, high_threshold):
+ return cv2.Canny(img, low_threshold, high_threshold)
diff --git a/sd-webui-controlnet/annotator/clipvision/__init__.py b/sd-webui-controlnet/annotator/clipvision/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb2f0508204ddff0d51c2223d16d205a05dd714
--- /dev/null
+++ b/sd-webui-controlnet/annotator/clipvision/__init__.py
@@ -0,0 +1,137 @@
+import os
+import cv2
+import torch
+
+from modules import devices
+from annotator.annotator_path import models_path
+from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor
+
+try:
+ from modules.modelloader import load_file_from_url
+except ImportError:
+ # backward compatibility for webui < 1.5.0
+ from basicsr.utils.download_util import load_file_from_url
+
+config_clip_g = {
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "gelu",
+ "hidden_size": 1664,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 8192,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 48,
+ "patch_size": 14,
+ "projection_dim": 1280,
+ "torch_dtype": "float32"
+}
+
+config_clip_h = {
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "gelu",
+ "hidden_size": 1280,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 5120,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 32,
+ "patch_size": 14,
+ "projection_dim": 1024,
+ "torch_dtype": "float32"
+}
+
+config_clip_vitl = {
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "image_size": 224,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-05,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768,
+ "torch_dtype": "float32"
+}
+
+configs = {
+ 'clip_g': config_clip_g,
+ 'clip_h': config_clip_h,
+ 'clip_vitl': config_clip_vitl,
+}
+
+downloads = {
+ 'clip_vitl': 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin',
+ 'clip_g': 'https://huggingface.co/lllyasviel/Annotators/resolve/main/clip_g.pth',
+ 'clip_h': 'https://huggingface.co/h94/IP-Adapter/resolve/main/models/image_encoder/pytorch_model.bin'
+}
+
+
+clip_vision_h_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_h_uc.data')
+clip_vision_h_uc = torch.load(clip_vision_h_uc, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))['uc']
+
+clip_vision_vith_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_vith_uc.data')
+clip_vision_vith_uc = torch.load(clip_vision_vith_uc, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))['uc']
+
+
+class ClipVisionDetector:
+ def __init__(self, config, low_vram: bool):
+ assert config in downloads
+ self.download_link = downloads[config]
+ self.model_path = os.path.join(models_path, 'clip_vision')
+ self.file_name = config + '.pth'
+ self.config = configs[config]
+ self.device = (
+ torch.device("cpu") if low_vram else
+ devices.get_device_for("controlnet")
+ )
+ os.makedirs(self.model_path, exist_ok=True)
+ file_path = os.path.join(self.model_path, self.file_name)
+ if not os.path.exists(file_path):
+ load_file_from_url(url=self.download_link, model_dir=self.model_path, file_name=self.file_name)
+ config = CLIPVisionConfig(**self.config)
+
+ self.model = CLIPVisionModelWithProjection(config)
+ self.processor = CLIPImageProcessor(crop_size=224,
+ do_center_crop=True,
+ do_convert_rgb=True,
+ do_normalize=True,
+ do_resize=True,
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
+ image_std=[0.26862954, 0.26130258, 0.27577711],
+ resample=3,
+ size=224)
+ sd = torch.load(file_path, map_location=self.device)
+ self.model.load_state_dict(sd, strict=False)
+ del sd
+ self.model.to(self.device)
+ self.model.eval()
+
+ def unload_model(self):
+ if self.model is not None:
+ self.model.to('meta')
+
+ def __call__(self, input_image):
+ with torch.no_grad():
+ input_image = cv2.resize(input_image, (224, 224), interpolation=cv2.INTER_AREA)
+ feat = self.processor(images=input_image, return_tensors="pt")
+ feat['pixel_values'] = feat['pixel_values'].to(self.device)
+ result = self.model(**feat, output_hidden_states=True)
+ result['hidden_states'] = [v.to(self.device) for v in result['hidden_states']]
+ result = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in result.items()}
+ return result
diff --git a/sd-webui-controlnet/annotator/clipvision/clip_vision_h_uc.data b/sd-webui-controlnet/annotator/clipvision/clip_vision_h_uc.data
new file mode 100644
index 0000000000000000000000000000000000000000..70c4a7bc9aeef7445c3974e2618c4a78745d3c9d
Binary files /dev/null and b/sd-webui-controlnet/annotator/clipvision/clip_vision_h_uc.data differ
diff --git a/sd-webui-controlnet/annotator/clipvision/clip_vision_vith_uc.data b/sd-webui-controlnet/annotator/clipvision/clip_vision_vith_uc.data
new file mode 100644
index 0000000000000000000000000000000000000000..0c0a61afe9f3a864dc46f3928ac621c33eeff303
Binary files /dev/null and b/sd-webui-controlnet/annotator/clipvision/clip_vision_vith_uc.data differ
diff --git a/sd-webui-controlnet/annotator/color/__init__.py b/sd-webui-controlnet/annotator/color/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..65799a2a83efd18dc556600c99d43292845aa6f2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/color/__init__.py
@@ -0,0 +1,20 @@
+import cv2
+
+def cv2_resize_shortest_edge(image, size):
+ h, w = image.shape[:2]
+ if h < w:
+ new_h = size
+ new_w = int(round(w / h * size))
+ else:
+ new_w = size
+ new_h = int(round(h / w * size))
+ resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
+ return resized_image
+
+def apply_color(img, res=512):
+ img = cv2_resize_shortest_edge(img, res)
+ h, w = img.shape[:2]
+
+ input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
+ input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
+ return input_img_color
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/densepose/__init__.py b/sd-webui-controlnet/annotator/densepose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc46bc2da929a8e31495d98be4bf5ba2e8cf0d0e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/densepose/__init__.py
@@ -0,0 +1,57 @@
+import torchvision # Fix issue Unknown builtin op: torchvision::nms
+import cv2
+import numpy as np
+import torch
+from einops import rearrange
+from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences
+from modules import devices
+from annotator.annotator_path import models_path
+import os
+
+# Used below to scale label values into the 0..255 range (val_scale = 255 / N_PART_LABELS);
+# presumably the number of DensePose body-part labels - TODO confirm.
+N_PART_LABELS = 24
+# Shared visualizer instance; apply_densepose sets its mask colormap on every call.
+result_visualizer = DensePoseMaskedColormapResultsVisualizer(
+ alpha=1,
+ data_extractor=_extract_i_from_iuvarr,
+ segm_extractor=_extract_i_from_iuvarr,
+ val_scale = 255.0 / N_PART_LABELS
+)
+# Download location of the TorchScript model, fetched when the file is missing from model_dir.
+remote_torchscript_path = "https://huggingface.co/LayerNorm/DensePose-TorchScript-with-hint-image/resolve/main/densepose_r50_fpn_dl.torchscript"
+# Module-level cache for the loaded model; filled by apply_densepose on first use.
+torchscript_model = None
+model_dir = os.path.join(models_path, "densepose")
+
+def apply_densepose(input_image, cmap="viridis"):
+ global torchscript_model
+ if torchscript_model is None:
+ model_path = os.path.join(model_dir, "densepose_r50_fpn_dl.torchscript")
+ if not os.path.exists(model_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_torchscript_path, model_dir=model_dir)
+ torchscript_model = torch.jit.load(model_path, map_location="cpu").to(devices.get_device_for("controlnet")).eval()
+ H, W = input_image.shape[:2]
+
+ hint_image_canvas = np.zeros([H, W], dtype=np.uint8)
+ hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3])
+ input_image = rearrange(torch.from_numpy(input_image).to(devices.get_device_for("controlnet")), 'h w c -> c h w')
+ pred_boxes, corase_segm, fine_segm, u, v = torchscript_model(input_image)
+
+ extractor = densepose_chart_predictor_output_to_result_with_confidences
+ densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))]
+
+ if cmap=="viridis":
+ result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS
+ hint_image = result_visualizer.visualize(hint_image_canvas, densepose_results)
+ hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
+ hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68
+ hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1
+ hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84
+ else:
+ result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA
+ hint_image = result_visualizer.visualize(hint_image_canvas, densepose_results)
+ hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB)
+
+ return hint_image
+
+def unload_model():
+ global torchscript_model
+ if torchscript_model is not None:
+ torchscript_model.cpu()
diff --git a/sd-webui-controlnet/annotator/densepose/densepose.py b/sd-webui-controlnet/annotator/densepose/densepose.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e43b05fcd76efd5774485ca35e715e64acefdbe
--- /dev/null
+++ b/sd-webui-controlnet/annotator/densepose/densepose.py
@@ -0,0 +1,347 @@
+from typing import Tuple
+import math
+import numpy as np
+from enum import IntEnum
+from typing import List, Tuple, Union
+import torch
+from torch.nn import functional as F
+import logging
+import cv2
+
+# Type aliases used in the annotations of the helpers in this module.
+Image = np.ndarray
+Boxes = torch.Tensor
+ImageSizeType = Tuple[int, int]
+# Every input form accepted by BoxMode.convert: a flat list/tuple or an array/tensor of boxes.
+_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray]
+# Four-integer box tuple; the resampling helpers in this module unpack it as (x, y, w, h).
+IntTupleBox = Tuple[int, int, int, int]
+
+class BoxMode(IntEnum):
+    """
+    Enum of different ways to represent a box.
+    """
+
+    XYXY_ABS = 0
+    """
+    (x0, y0, x1, y1) in absolute floating points coordinates.
+    The coordinates in range [0, width or height].
+    """
+    XYWH_ABS = 1
+    """
+    (x0, y0, w, h) in absolute floating points coordinates.
+    """
+    XYXY_REL = 2
+    """
+    Not yet supported!
+    (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image.
+    """
+    XYWH_REL = 3
+    """
+    Not yet supported!
+    (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image.
+    """
+    XYWHA_ABS = 4
+    """
+    (xc, yc, w, h, a) in absolute floating points coordinates.
+    (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw.
+    """
+
+    @staticmethod
+    def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType:
+        """
+        Convert a box, or a batch of boxes, from one representation to another.
+
+        Args:
+            box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5
+            from_mode, to_mode (BoxMode)
+
+        Returns:
+            The converted box of the same type. When ``from_mode == to_mode`` the
+            input object itself is returned; otherwise the input is left
+            unmodified and a new object is returned.
+
+        Raises:
+            AssertionError: if a list/tuple box does not have 4 or 5 entries, if
+                either mode is a relative mode, or if an XYWHA input does not
+                have 5 columns.
+            NotImplementedError: for any pair of modes other than
+                XYWHA->XYXY, XYWH->XYWHA, XYWH->XYXY and XYXY->XYWH.
+        """
+        if from_mode == to_mode:
+            return box
+
+        original_type = type(box)
+        is_numpy = isinstance(box, np.ndarray)
+        single_box = isinstance(box, (list, tuple))
+        if single_box:
+            assert len(box) == 4 or len(box) == 5, (
+                "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor,"
+                " where k == 4 or 5"
+            )
+            arr = torch.tensor(box)[None, :]
+        else:
+            # avoid modifying the input box
+            if is_numpy:
+                arr = torch.from_numpy(np.asarray(box)).clone()
+            else:
+                arr = box.clone()
+
+        assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [
+            BoxMode.XYXY_REL,
+            BoxMode.XYWH_REL,
+        ], "Relative mode not yet supported!"
+
+        if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS:
+            assert (
+                arr.shape[-1] == 5
+            ), "The last dimension of input shape must be 5 for XYWHA format"
+            original_dtype = arr.dtype
+            # Do the trigonometry in double precision, then cast back.
+            arr = arr.double()
+
+            w = arr[:, 2]
+            h = arr[:, 3]
+            a = arr[:, 4]
+            c = torch.abs(torch.cos(a * math.pi / 180.0))
+            s = torch.abs(torch.sin(a * math.pi / 180.0))
+            # This basically computes the horizontal bounding rectangle of the rotated box
+            new_w = c * w + s * h
+            new_h = c * h + s * w
+
+            # convert center to top-left corner
+            arr[:, 0] -= new_w / 2.0
+            arr[:, 1] -= new_h / 2.0
+            # bottom-right corner
+            arr[:, 2] = arr[:, 0] + new_w
+            arr[:, 3] = arr[:, 1] + new_h
+
+            arr = arr[:, :4].to(dtype=original_dtype)
+        elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS:
+            original_dtype = arr.dtype
+            arr = arr.double()
+            arr[:, 0] += arr[:, 2] / 2.0
+            arr[:, 1] += arr[:, 3] / 2.0
+            # The angle column must live on the same device as the boxes,
+            # otherwise torch.cat fails for non-CPU tensors.
+            angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype, device=arr.device)
+            arr = torch.cat((arr, angles), dim=1).to(dtype=original_dtype)
+        else:
+            if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS:
+                arr[:, 2] += arr[:, 0]
+                arr[:, 3] += arr[:, 1]
+            elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS:
+                arr[:, 2] -= arr[:, 0]
+                arr[:, 3] -= arr[:, 1]
+            else:
+                raise NotImplementedError(
+                    "Conversion from BoxMode {} to {} is not supported yet".format(
+                        from_mode, to_mode
+                    )
+                )
+
+        if single_box:
+            # Rebuild the caller's container type (list or tuple) from the flat values.
+            return original_type(arr.flatten().tolist())
+        if is_numpy:
+            return arr.numpy()
+        else:
+            return arr
+
+class MatrixVisualizer:
+    """
+    Base visualizer for matrix data
+
+    Maps a 2D matrix through an OpenCV colormap and alpha-blends the result onto
+    the bounding-box region of a BGR image. Pixels whose mask value is 0 keep
+    the image content.
+    """
+
+    def __init__(
+        self,
+        inplace=True,
+        cmap=cv2.COLORMAP_PARULA,
+        val_scale=1.0,
+        alpha=0.7,
+        interp_method_matrix=cv2.INTER_LINEAR,
+        interp_method_mask=cv2.INTER_NEAREST,
+    ):
+        """
+        Args:
+            inplace: draw onto the given image when True, onto a zeroed copy otherwise
+            cmap: OpenCV colormap id passed to ``cv2.applyColorMap``
+            val_scale: factor applied to the matrix before conversion to uint8
+            alpha: blend weight of the colormapped matrix
+            interp_method_matrix: OpenCV interpolation flag used to resize the matrix
+            interp_method_mask: OpenCV interpolation flag used to resize the mask
+        """
+        self.inplace = inplace
+        self.cmap = cmap
+        self.val_scale = val_scale
+        self.alpha = alpha
+        self.interp_method_matrix = interp_method_matrix
+        self.interp_method_mask = interp_method_mask
+
+    def visualize(self, image_bgr, mask, matrix, bbox_xywh):
+        """
+        Blend the colormapped ``matrix`` into ``image_bgr`` inside ``bbox_xywh``.
+
+        Args:
+            image_bgr: uint8 image with 3 channels
+            mask: 2D uint8 array; 0 marks pixels to leave untouched
+            matrix: 2D array of values to colormap
+            bbox_xywh: (x, y, w, h) of the target region; values are cast to int
+
+        Returns:
+            A uint8 copy of the drawn image. If the box has non-positive width
+            or height, ``image_bgr`` itself is returned unchanged.
+        """
+        self._check_image(image_bgr)
+        self._check_mask_matrix(mask, matrix)
+        if self.inplace:
+            image_target_bgr = image_bgr
+        else:
+            image_target_bgr = image_bgr * 0
+        x, y, w, h = [int(v) for v in bbox_xywh]
+        if w <= 0 or h <= 0:
+            return image_bgr
+        mask, matrix = self._resize(mask, matrix, w, h)
+        mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3])
+        matrix_scaled = matrix.astype(np.float32) * self.val_scale
+        _EPSILON = 1e-6
+        if np.any(matrix_scaled > 255 + _EPSILON):
+            logger = logging.getLogger(__name__)
+            logger.warning(
+                f"Matrix has values > {255 + _EPSILON} after " f"scaling, clipping to [0..255]"
+            )
+        matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8)
+        matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap)
+        # Background pixels take the target image's colour, so blending leaves them unchanged.
+        matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg]
+        image_target_bgr[y : y + h, x : x + w, :] = (
+            image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha
+        )
+        return image_target_bgr.astype(np.uint8)
+
+    def _resize(self, mask, matrix, w, h):
+        """Resize ``mask`` and ``matrix`` to ``(h, w)`` when their shapes differ from it."""
+        # The interpolation flag must be passed by keyword: the third positional
+        # parameter of cv2.resize is ``dst``, not ``interpolation``.
+        if (w != mask.shape[1]) or (h != mask.shape[0]):
+            mask = cv2.resize(mask, (w, h), interpolation=self.interp_method_mask)
+        if (w != matrix.shape[1]) or (h != matrix.shape[0]):
+            matrix = cv2.resize(matrix, (w, h), interpolation=self.interp_method_matrix)
+        return mask, matrix
+
+    def _check_image(self, image_rgb):
+        """Assert that the image is a 3-channel uint8 array."""
+        assert len(image_rgb.shape) == 3
+        assert image_rgb.shape[2] == 3
+        assert image_rgb.dtype == np.uint8
+
+    def _check_mask_matrix(self, mask, matrix):
+        """Assert that mask and matrix are 2D and that the mask is uint8."""
+        assert len(matrix.shape) == 2
+        assert len(mask.shape) == 2
+        assert mask.dtype == np.uint8
+
+class DensePoseResultsVisualizer:
+    """Base class that packs each DensePose result into an IUV array and passes it to a drawing hook."""
+
+    def visualize(
+        self,
+        image_bgr: Image,
+        results,
+    ) -> Image:
+        """
+        Draw every result onto ``image_bgr``.
+
+        Each item of ``results`` is unpacked as ``(boxes_xywh, labels, uv)``.
+        The labels and the UV planes (scaled by 255) are stacked into one
+        uint8 array and handed to ``visualize_iuv_arr``.
+        """
+        context = self.create_visualization_context(image_bgr)
+        for boxes_xywh, labels, uv in results:
+            label_plane = labels[None].type(torch.float32)
+            iuv_tensor = torch.cat((label_plane, uv * 255.0)).type(torch.uint8)
+            self.visualize_iuv_arr(context, iuv_tensor.cpu().numpy(), boxes_xywh)
+        image_bgr = self.context_to_image_bgr(context)
+        return image_bgr
+
+    def create_visualization_context(self, image_bgr: Image):
+        """Return the drawing context; in this base class it is the image itself."""
+        return image_bgr
+
+    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
+        """Drawing hook for subclasses; does nothing here."""
+        pass
+
+    def context_to_image_bgr(self, context):
+        """Turn the context back into an image; identity in this base class."""
+        return context
+
+    def get_image_bgr_from_context(self, context):
+        """Return the image to draw on for the given context; identity in this base class."""
+        return context
+
+class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer):
+    """Draws one plane extracted from an IUV array as a colormap, masked by a second extracted plane."""
+
+    def __init__(
+        self,
+        data_extractor,
+        segm_extractor,
+        inplace=True,
+        cmap=cv2.COLORMAP_PARULA,
+        alpha=0.7,
+        val_scale=1.0,
+        **kwargs,
+    ):
+        """
+        Args:
+            data_extractor: callable mapping an IUV array to the matrix to colormap
+            segm_extractor: callable mapping an IUV array to the segmentation plane
+            inplace, cmap, alpha, val_scale: forwarded to ``MatrixVisualizer``
+            **kwargs: accepted and ignored
+        """
+        self.data_extractor = data_extractor
+        self.segm_extractor = segm_extractor
+        self.mask_visualizer = MatrixVisualizer(
+            inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha
+        )
+
+    def context_to_image_bgr(self, context):
+        """The context is the image itself, so return it unchanged."""
+        return context
+
+    def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None:
+        """Draw the extracted data plane wherever the extracted segmentation plane is positive."""
+        canvas = self.get_image_bgr_from_context(context)
+        values = self.data_extractor(iuv_arr)
+        segmentation = self.segm_extractor(iuv_arr)
+        foreground = np.zeros(values.shape, dtype=np.uint8)
+        foreground[segmentation > 0] = 1
+        # The image returned by the visualizer is not used here.
+        self.mask_visualizer.visualize(canvas, foreground, values, bbox_xywh)
+
+
+def _extract_i_from_iuvarr(iuv_arr):
+    """Return plane 0 of an IUV array (the I plane, going by the function name)."""
+    i_plane = iuv_arr[0, :, :]
+    return i_plane
+
+
+def _extract_u_from_iuvarr(iuv_arr):
+    """Return plane 1 of an IUV array (the U plane, going by the function name)."""
+    u_plane = iuv_arr[1, :, :]
+    return u_plane
+
+
+def _extract_v_from_iuvarr(iuv_arr):
+    """Return plane 2 of an IUV array (the V plane, going by the function name)."""
+    v_plane = iuv_arr[2, :, :]
+    return v_plane
+
+def make_int_box(box: torch.Tensor) -> IntTupleBox:
+    """Cast a 4-element box tensor to integers and return it as a plain tuple."""
+    first, second, third, fourth = box.long().tolist()
+    return first, second, third, fourth
+
+def densepose_chart_predictor_output_to_result_with_confidences(
+    boxes: Boxes,
+    coarse_segm,
+    fine_segm,
+    u, v
+
+):
+    """
+    Turn raw chart-predictor outputs for one detection into a drawable result.
+
+    Args:
+        boxes: XYXY_ABS boxes; only the first row is used
+        coarse_segm: coarse segmentation scores, resampled to the box size
+        fine_segm: fine segmentation scores, resampled to the box size
+        u, v: U and V coordinate estimates, resampled to the box size
+
+    Returns:
+        ``(box_xywh, labels, uv)``: the integer (x, y, w, h) box, the per-pixel
+        labels with the leading dimension squeezed out, and the resampled UV
+        tensor. Despite the function name, no confidences are returned.
+    """
+    # BoxMode.convert clones tensor input before modifying it, so `boxes` is not mutated.
+    boxes_xywh_abs = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+    box_xywh = make_int_box(boxes_xywh_abs[0])
+
+    labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0)
+    uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh)
+    return box_xywh, labels, uv
+
+def resample_fine_and_coarse_segm_tensors_to_bbox(
+    fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox
+):
+    """
+    Resize the fine and coarse segmentation scores to the bounding-box size and
+    combine them into one label per box pixel.
+
+    Args:
+        fine_segm: float tensor of shape [1, C, Hout, Wout]
+        coarse_segm: float tensor of shape [1, K, Hout, Wout]
+        box_xywh_abs (tuple of 4 int): bounding box given by its upper-left
+            corner coordinates, width (W) and height (H)
+    Return:
+        Labels for each pixel of the bounding box, a long tensor of size [1, H, W]
+    """
+    _, _, box_w, box_h = box_xywh_abs
+    # Width and height are clamped to at least 1 before interpolating.
+    target_size = (max(int(box_h), 1), max(int(box_w), 1))
+
+    coarse_labels = F.interpolate(
+        coarse_segm, target_size, mode="bilinear", align_corners=False
+    ).argmax(dim=1)
+    fine_labels = F.interpolate(
+        fine_segm, target_size, mode="bilinear", align_corners=False
+    ).argmax(dim=1)
+
+    # Fine labels are kept only where the coarse label is non-zero.
+    foreground = (coarse_labels > 0).long()
+    return fine_labels * foreground
+
+def resample_uv_tensors_to_bbox(
+    u: torch.Tensor,
+    v: torch.Tensor,
+    labels: torch.Tensor,
+    box_xywh_abs: IntTupleBox,
+) -> torch.Tensor:
+    """
+    Resize the U and V estimates to the bounding-box size and, for every pixel,
+    pick the value from the channel named by that pixel's label.
+
+    Args:
+        u (tensor [1, C, H, W] of float): U coordinates
+        v (tensor [1, C, H, W] of float): V coordinates
+        labels (tensor [H, W] of long): labels obtained by resampling segmentation
+            outputs for the given bounding box
+        box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs
+    Return:
+        Resampled U and V coordinates - a tensor [2, H, W] of float
+    """
+    _, _, box_w, box_h = box_xywh_abs
+    out_h = max(int(box_h), 1)
+    out_w = max(int(box_w), 1)
+    u_bbox = F.interpolate(u, (out_h, out_w), mode="bilinear", align_corners=False)
+    v_bbox = F.interpolate(v, (out_h, out_w), mode="bilinear", align_corners=False)
+    uv = torch.zeros([2, out_h, out_w], dtype=torch.float32, device=u.device)
+    # Label 0 is skipped, so those pixels keep the zero initialisation.
+    for part_id in range(1, u_bbox.size(1)):
+        part_pixels = labels == part_id
+        uv[0][part_pixels] = u_bbox[0, part_id][part_pixels]
+        uv[1][part_pixels] = v_bbox[0, part_id][part_pixels]
+    return uv
+
diff --git a/sd-webui-controlnet/annotator/depth_anything.py b/sd-webui-controlnet/annotator/depth_anything.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe480c5e0947a3b4b7bbd2b8fce45a8c783e936
--- /dev/null
+++ b/sd-webui-controlnet/annotator/depth_anything.py
@@ -0,0 +1,79 @@
+import os
+import torch
+import cv2
+import numpy as np
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+from depth_anything.dpt import DPT_DINOv2
+from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
+from .util import load_model
+from .annotator_path import models_path
+
+
+# Preprocessing pipeline applied to the input image in DepthAnythingDetector.__call__
+# before it is fed to the model: resize, normalise, then convert to network layout.
+# NOTE(review): the exact resize semantics come from depth_anything.util.transform,
+# which is not visible here; confirm there.
+transform = Compose(
+ [
+ Resize(
+ width=518,
+ height=518,
+ resize_target=False,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=14,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ PrepareForNet(),
+ ]
+)
+
+
+class DepthAnythingDetector:
+    """Depth estimator built on https://github.com/LiheYoung/Depth-Anything"""
+
+    model_dir = os.path.join(models_path, "depth_anything")
+
+    def __init__(self, device: torch.device):
+        """
+        Build the DPT_DINOv2 network on ``device`` and load its weights.
+
+        The checkpoint URL can be overridden with the environment variable
+        ``CONTROLNET_DEPTH_ANYTHING_MODEL_URL``.
+        """
+        self.device = device
+        self.model = (
+            DPT_DINOv2(
+                encoder="vitl",
+                features=256,
+                out_channels=[256, 512, 1024, 1024],
+                localhub=False,
+            )
+            .to(device)
+            .eval()
+        )
+        remote_url = os.environ.get(
+            "CONTROLNET_DEPTH_ANYTHING_MODEL_URL",
+            "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth",
+        )
+        model_path = load_model(
+            "depth_anything_vitl14.pth", remote_url=remote_url, model_dir=self.model_dir
+        )
+        # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on
+        # machines without one; load_state_dict copies into the model's own device.
+        self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
+
+    def __call__(self, image: np.ndarray, colored: bool = True) -> np.ndarray:
+        """
+        Estimate a depth map for ``image``.
+
+        Args:
+            image: H x W x 3 array; it is passed through ``cv2.COLOR_BGR2RGB`` first
+            colored: when True, return an INFERNO colormap rendering with the
+                channel order reversed; otherwise return the single-channel map
+
+        Returns:
+            uint8 array of the input's height and width, min-max normalised to
+            [0, 255]. A constant depth prediction yields an all-zero map.
+        """
+        # The model may have been moved to the CPU by unload_model().
+        self.model.to(self.device)
+        h, w = image.shape[:2]
+
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
+        image = transform({"image": image})["image"]
+        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
+
+        with torch.no_grad():
+            depth = self.model(image)
+            depth = F.interpolate(
+                depth[None], (h, w), mode="bilinear", align_corners=False
+            )[0, 0]
+            depth_min = depth.min()
+            depth_range = depth.max() - depth_min
+            if depth_range > 0:
+                depth = (depth - depth_min) / depth_range * 255.0
+            else:
+                # Flat prediction: avoid 0/0, which would put NaNs into the uint8 cast.
+                depth = torch.zeros_like(depth)
+
+        depth = depth.cpu().numpy().astype(np.uint8)
+        if colored:
+            return cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
+        else:
+            return depth
+
+    def unload_model(self):
+        """Move the model to the CPU."""
+        self.model.to("cpu")
diff --git a/sd-webui-controlnet/annotator/hed/__init__.py b/sd-webui-controlnet/annotator/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cabfef0db9089e415812dd8311a091397b39966
--- /dev/null
+++ b/sd-webui-controlnet/annotator/hed/__init__.py
@@ -0,0 +1,98 @@
+# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
+# Please use this implementation in your products
+# This implementation may produce slightly different results from Saining Xie's official implementations,
+# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
+# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
+# and in this way it works better for gradio's RGB protocol
+
+import os
+import cv2
+import torch
+import numpy as np
+
+from einops import rearrange
+import os
+from modules import devices
+from annotator.annotator_path import models_path
+from annotator.util import safe_step, nms
+
+
+class DoubleConvBlock(torch.nn.Module):
+    """A stack of 3x3 convolutions, each followed by ReLU, plus a 1x1 projection to one channel."""
+
+    def __init__(self, input_channel, output_channel, layer_number):
+        super().__init__()
+        shared = dict(out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)
+        # The first convolution changes the channel count; the rest keep it.
+        layers = [torch.nn.Conv2d(in_channels=input_channel, **shared)]
+        layers += [torch.nn.Conv2d(in_channels=output_channel, **shared) for _ in range(1, layer_number)]
+        self.convs = torch.nn.Sequential(*layers)
+        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
+
+    def __call__(self, x, down_sampling=False):
+        """Return ``(features, projection)``, applying 2x2 max-pooling to ``x`` first when requested."""
+        if down_sampling:
+            h = torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
+        else:
+            h = x
+        for conv in self.convs:
+            h = torch.nn.functional.relu(conv(h))
+        return h, self.projection(h)
+
+
+class ControlNetHED_Apache2(torch.nn.Module):
+    """Five chained DoubleConvBlocks; calling the module returns the five per-block projections."""
+
+    def __init__(self):
+        super().__init__()
+        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
+        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
+        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
+        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
+        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
+        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
+
+    def __call__(self, x):
+        """Subtract ``norm`` from ``x`` and run the blocks, down-sampling before blocks 2 to 5."""
+        features, first_projection = self.block1(x - self.norm)
+        projections = [first_projection]
+        for block in (self.block2, self.block3, self.block4, self.block5):
+            features, projection = block(features, down_sampling=True)
+            projections.append(projection)
+        return tuple(projections)
+
+
+# Cached HED network; stays None until apply_hed() builds it on first use.
+netNetwork = None
+# URL the weights are downloaded from when no local copy is found.
+remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
+# Directory the weights are downloaded into.
+modeldir = os.path.join(models_path, "hed")
+# Directory of this module; apply_hed() prefers a weights file found here over modeldir.
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+
+
+def apply_hed(input_image, is_safe=False):
+    """
+    Run HED edge detection on an image.
+
+    Args:
+        input_image: H x W x C numpy array (asserted to be 3-dimensional)
+        is_safe: when True, the edge map is passed through ``safe_step`` before
+            conversion to uint8
+
+    Returns:
+        H x W uint8 edge map: the sigmoid of the mean of the five resized
+        projections, scaled to [0, 255].
+    """
+    global netNetwork
+    device = devices.get_device_for("controlnet")
+    if netNetwork is None:
+        modelpath = os.path.join(modeldir, "ControlNetHED.pth")
+        old_modelpath = os.path.join(old_modeldir, "ControlNetHED.pth")
+        if os.path.exists(old_modelpath):
+            modelpath = old_modelpath
+        elif not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=modeldir)
+        # Load into a local first: if loading fails, the global stays None instead
+        # of caching a randomly initialised network for later calls.
+        model = ControlNetHED_Apache2()
+        model.load_state_dict(torch.load(modelpath, map_location='cpu'))
+        netNetwork = model
+    # Done on every call because unload_hed_model() may have moved the network to the CPU.
+    netNetwork.to(device).float().eval()
+
+    assert input_image.ndim == 3
+    H, W = input_image.shape[:2]
+    with torch.no_grad():
+        image_hed = torch.from_numpy(input_image.copy()).float().to(device)
+        image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+        edges = netNetwork(image_hed)
+        edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
+        # The projections come from different resolutions; resize them all to the input size.
+        edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
+        edges = np.stack(edges, axis=2)
+        edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
+        if is_safe:
+            edge = safe_step(edge)
+        edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
+        return edge
+
+
+def unload_hed_model():
+    """Move the cached HED network to the CPU, if one has been created."""
+    global netNetwork
+    if netNetwork is None:
+        # apply_hed() has not built the network yet.
+        return
+    netNetwork.cpu()
diff --git a/sd-webui-controlnet/annotator/keypose/__init__.py b/sd-webui-controlnet/annotator/keypose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa3dfa2e1589f22471411b3180ccaf870f147d73
--- /dev/null
+++ b/sd-webui-controlnet/annotator/keypose/__init__.py
@@ -0,0 +1,212 @@
+import numpy as np
+import cv2
+import torch
+
+import os
+from modules import devices
+from annotator.annotator_path import models_path
+
+import mmcv
+from mmdet.apis import inference_detector, init_detector
+from mmpose.apis import inference_top_down_pose_model
+from mmpose.apis import init_pose_model, process_mmdet_results, vis_pose_result
+
+
+def preprocessing(image, device):
+    """
+    Resize an image so its longer side is 640 pixels, subtract per-channel
+    means, and convert it to a batched float tensor on ``device``.
+
+    Returns:
+        ``(tensor, raw_image)``: the 1 x C x H x W float tensor and the resized
+        image cast to uint8.
+    """
+    # Resize
+    longest_side = max(image.shape[:2])
+    ratio = 640 / longest_side
+    resized = cv2.resize(image, dsize=None, fx=ratio, fy=ratio)
+    raw_image = resized.astype(np.uint8)
+
+    # Subtract mean values
+    channel_means = np.array([104.008, 116.669, 122.675])
+    centered = resized.astype(np.float32)
+    centered -= channel_means
+
+    # Convert to torch.Tensor and add "batch" axis
+    batched = torch.from_numpy(centered.transpose(2, 0, 1)).float().unsqueeze(0)
+    return batched.to(device), raw_image
+
+
+def imshow_keypoints(img,
+                     pose_result,
+                     skeleton=None,
+                     kpt_score_thr=0.1,
+                     pose_kpt_color=None,
+                     pose_link_color=None,
+                     radius=4,
+                     thickness=1):
+    """Draw keypoints and links on a blank canvas with the shape of ``img``.
+
+    Only the first two entries of ``pose_result`` are drawn.
+
+    Args:
+        img (ndarray): Image whose shape defines the canvas; its pixels are
+            not used.
+        pose_result (list[dict]): The poses to draw. Each element holds, under
+            the key ``'keypoints'``, a set of K keypoints as a Kx3 array-like,
+            where each keypoint is represented as x, y, score.
+        skeleton (list[list[int]], optional): Pairs of keypoint indices to
+            connect. If None, the links will not be drawn.
+        kpt_score_thr (float, optional): Minimum score of keypoints
+            to be shown. Default: 0.1.
+        pose_kpt_color (np.array[Nx3]): Color of N keypoints. If None,
+            the keypoint will not be drawn.
+        pose_link_color (np.array[Mx3]): Color of M links. If None, the
+            links will not be drawn.
+        radius (int): Radius of the keypoint circles.
+        thickness (int): Thickness of lines.
+
+    Returns:
+        ndarray: The float canvas (created with ``np.zeros``) with the poses
+        drawn on it.
+    """
+
+    img_h, img_w, _ = img.shape
+    img = np.zeros(img.shape)
+
+    for idx, kpts in enumerate(pose_result):
+        if idx > 1:
+            continue
+        kpts = kpts['keypoints']
+        # np.asarray avoids a copy when possible and, unlike
+        # np.array(..., copy=False) on NumPy 2, never raises when one is needed.
+        kpts = np.asarray(kpts)
+
+        # draw each point on image
+        if pose_kpt_color is not None:
+            assert len(pose_kpt_color) == len(kpts)
+
+            for kid, kpt in enumerate(kpts):
+                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+
+                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
+                    # skip the point that should not be drawn
+                    continue
+
+                color = tuple(int(c) for c in pose_kpt_color[kid])
+                cv2.circle(img, (int(x_coord), int(y_coord)),
+                           radius, color, -1)
+
+        # draw links
+        if skeleton is not None and pose_link_color is not None:
+            assert len(pose_link_color) == len(skeleton)
+
+            for sk_id, sk in enumerate(skeleton):
+                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+
+                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
+                        or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
+                        or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
+                    # skip the link that should not be drawn
+                    continue
+                color = tuple(int(c) for c in pose_link_color[sk_id])
+                cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+    return img
+
+
+# Cached detector and pose model; both stay None until apply_keypose() creates them.
+human_det, pose_model = None, None
+# Download URLs of the two checkpoints.
+det_model_path = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
+pose_model_path = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"
+
+# Download directory, and this module's own directory (checked first by find_download_model).
+modeldir = os.path.join(models_path, "keypose")
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+
+# Config file names passed to mmcv.Config.fromfile in apply_keypose().
+det_config = 'faster_rcnn_r50_fpn_coco.py'
+pose_config = 'hrnet_w48_coco_256x192.py'
+
+# Checkpoint file names looked up or downloaded by find_download_model().
+det_checkpoint = 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
+pose_checkpoint = 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
+# Category id handed to process_mmdet_results (the code comment there says "person class").
+det_cat_id = 1
+# Bounding-box score threshold handed to inference_top_down_pose_model.
+bbox_thr = 0.2
+
+# Pairs of keypoint indices connected by a link; one entry per colour in pose_link_color.
+skeleton = [
+ [15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8],
+ [7, 9], [8, 10],
+ [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]
+]
+
+# One colour triple per keypoint (17 entries).
+pose_kpt_color = [
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
+ [0, 255, 0],
+ [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0],
+ [255, 128, 0],
+ [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]
+]
+
+# One colour triple per skeleton link (19 entries).
+pose_link_color = [
+ [0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
+ [255, 128, 0],
+ [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255],
+ [51, 153, 255],
+ [51, 153, 255], [51, 153, 255], [51, 153, 255]
+]
+
+def find_download_model(checkpoint, remote_path):
+    """
+    Return the local path of ``checkpoint``, downloading it when it is missing.
+
+    A copy in this module's directory takes precedence. Otherwise the file is
+    expected in ``modeldir`` and is fetched from ``remote_path`` if absent.
+    """
+    legacy_path = os.path.join(old_modeldir, checkpoint)
+    if os.path.exists(legacy_path):
+        return legacy_path
+
+    target_path = os.path.join(modeldir, checkpoint)
+    if not os.path.exists(target_path):
+        from basicsr.utils.download_util import load_file_from_url
+        load_file_from_url(remote_path, model_dir=modeldir)
+    return target_path
+
+def apply_keypose(input_image):
+    """
+    Detect people in an image and render their keypoint skeletons.
+
+    The detector and pose model are created on first use and cached in the
+    module globals ``human_det`` and ``pose_model``.
+
+    Args:
+        input_image: H x W x C numpy image (asserted to be 3-dimensional).
+
+    Returns:
+        uint8 array of the input's shape with the skeletons drawn on a black
+        canvas.
+    """
+    global human_det, pose_model
+    if human_det is None or pose_model is None:
+        device = devices.get_device_for("controlnet")
+        # The config files sit next to this module; resolve them from here
+        # rather than from the process's working directory.
+        config_dir = os.path.dirname(os.path.realpath(__file__))
+        det_model_local = find_download_model(det_checkpoint, det_model_path)
+        hrnet_model_local = find_download_model(pose_checkpoint, pose_model_path)
+        det_config_mmcv = mmcv.Config.fromfile(os.path.join(config_dir, det_config))
+        pose_config_mmcv = mmcv.Config.fromfile(os.path.join(config_dir, pose_config))
+        human_det = init_detector(det_config_mmcv, det_model_local, device=device)
+        pose_model = init_pose_model(pose_config_mmcv, hrnet_model_local, device=device)
+
+    assert input_image.ndim == 3
+    # mmdet and mmpose take a loaded image as a numpy array and normalise it
+    # in their own pipelines, so the unscaled array is passed, not a tensor.
+    # NOTE(review): those pipelines presumably expect BGR channel order;
+    # confirm what callers pass in.
+    image = input_image.copy()
+    with torch.no_grad():
+        mmdet_results = inference_detector(human_det, image)
+
+        # keep the person class bounding boxes.
+        person_results = process_mmdet_results(mmdet_results, det_cat_id)
+
+        return_heatmap = False
+        dataset = pose_model.cfg.data['test']['type']
+
+        # e.g. use ('backbone', ) to return backbone feature
+        output_layer_names = None
+        pose_results, _ = inference_top_down_pose_model(
+            pose_model,
+            image,
+            person_results,
+            bbox_thr=bbox_thr,
+            format='xyxy',
+            dataset=dataset,
+            dataset_info=None,
+            return_heatmap=return_heatmap,
+            outputs=output_layer_names
+        )
+
+    im_keypose_out = imshow_keypoints(
+        image,
+        pose_results,
+        skeleton=skeleton,
+        pose_kpt_color=pose_kpt_color,
+        pose_link_color=pose_link_color,
+        radius=2,
+        thickness=2
+    )
+    im_keypose_out = im_keypose_out.astype(np.uint8)
+    return im_keypose_out
+
+
+def unload_hed_model():
+    """
+    Move the cached keypose models to the CPU.
+
+    The function keeps its original name, but this module has no
+    ``netNetwork``; the models it caches are ``human_det`` and ``pose_model``.
+    """
+    global human_det, pose_model
+    for model in (human_det, pose_model):
+        if model is not None:
+            model.cpu()
diff --git a/sd-webui-controlnet/annotator/keypose/faster_rcnn_r50_fpn_coco.py b/sd-webui-controlnet/annotator/keypose/faster_rcnn_r50_fpn_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ad9528b22163ae7ce1390375b69227fd6eafd9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/keypose/faster_rcnn_r50_fpn_coco.py
@@ -0,0 +1,182 @@
+# Runtime, logging and optimisation settings of the Faster R-CNN detector config.
+# NOTE(review): keypose/__init__.py loads this file only to build a model for
+# inference, so these training-time settings are presumably unused there; confirm.
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[8, 11])
+total_epochs = 12
+
+# Detector definition: ResNet-50 backbone, FPN neck, RPN head and a standard RoI head.
+model = dict(
+ type='FasterRCNN',
+ pretrained='torchvision://resnet50',
+ # Backbone.
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch'),
+ # Neck: in_channels lists one entry per backbone output index.
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=5),
+ # Region proposal head.
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ # RoI head: box classification and regression over 80 classes.
+ roi_head=dict(
+ type='StandardRoIHead',
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='Shared2FCBBoxHead',
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_pre=2000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=False,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ pos_weight=-1,
+ debug=False)),
+ # Test-time proposal and detection settings.
+ test_cfg=dict(
+ rpn=dict(
+ nms_pre=1000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100)
+ # soft-nms is also supported for rcnn testing
+ # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
+ ))
+
+# Dataset and pipeline settings. All dataset paths are built from data_root.
+dataset_type = 'CocoDataset'
+data_root = 'data/coco'
+# Normalisation parameters shared by the train and test pipelines.
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+# Test pipeline: a single scale with flipping disabled.
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+# The val and test splits use the same annotation file and pipeline.
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='bbox')
diff --git a/sd-webui-controlnet/annotator/keypose/hrnet_w48_coco_256x192.py b/sd-webui-controlnet/annotator/keypose/hrnet_w48_coco_256x192.py
new file mode 100644
index 0000000000000000000000000000000000000000..9755e6773cd3a8c0d2ac684c612d716cfd44b0ca
--- /dev/null
+++ b/sd-webui-controlnet/annotator/keypose/hrnet_w48_coco_256x192.py
@@ -0,0 +1,169 @@
+# Evaluation, optimiser, schedule and channel settings of the HRNet-W48 pose config.
+# NOTE(review): the _base_ list below is commented out, yet the `data` dict at the
+# end of this file still uses {{_base_.dataset_info}}; confirm how the config
+# loader resolves that placeholder without a base file.
+# _base_ = [
+# '../../../../_base_/default_runtime.py',
+# '../../../../_base_/datasets/coco.py'
+# ]
+evaluation = dict(interval=10, metric='mAP', save_best='AP')
+
+optimizer = dict(
+ type='Adam',
+ lr=5e-4,
+)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[170, 200])
+total_epochs = 210
+# Keypoint channel layout: 17 joints, referenced by the model and data_cfg sections.
+channel_cfg = dict(
+ num_output_channels=17,
+ dataset_joints=17,
+ dataset_channel=[
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+ ],
+ inference_channel=[
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+ ])
+
+# model settings
+# Top-down pose estimator: HRNet backbone with a heatmap head that emits one
+# channel per keypoint (channel_cfg['num_output_channels']).
+model = dict(
+ type='TopDown',
+ pretrained='https://download.openmmlab.com/mmpose/'
+ 'pretrain_models/hrnet_w48-8ef0771d.pth',
+ backbone=dict(
+ type='HRNet',
+ in_channels=3,
+ # Each stage lists one block count and one channel width per branch.
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(48, 96)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(48, 96, 192)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(48, 96, 192, 384))),
+ ),
+ keypoint_head=dict(
+ type='TopdownHeatmapSimpleHead',
+ in_channels=48,
+ out_channels=channel_cfg['num_output_channels'],
+ num_deconv_layers=0,
+ extra=dict(final_conv_kernel=1, ),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+# Dataset-level settings shared by the train, val and test dataset entries below.
+# NOTE(review): image_size and heatmap_size look like [width, height] given the
+# file name "256x192"; confirm against the mmpose documentation.
+data_cfg = dict(
+ image_size=[192, 256],
+ heatmap_size=[48, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'],
+ soft_nms=False,
+ nms_thr=1.0,
+ oks_thr=0.9,
+ vis_thr=0.2,
+ use_gt_bbox=False,
+ det_bbox_thr=0.0,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+)
+
+# Training pipeline: loading, bbox-based cropping with random augmentation,
+# tensor conversion, normalisation and heatmap target generation.
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
+ dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
+ dict(type='TopDownRandomFlip', flip_prob=0.5),
+ dict(
+ type='TopDownHalfBodyTransform',
+ num_joints_half_body=8,
+ prob_half_body=0.3),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTarget', sigma=2),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs'
+ ]),
+]
+
+# Validation pipeline: the same steps without the random augmentations and targets.
+val_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs'
+ ]),
+]
+
+# Testing reuses the validation pipeline object.
+test_pipeline = val_pipeline
+
+# Dataset entries; apply_keypose() in keypose/__init__.py reads data['test']['type']
+# from this config.
+# NOTE(review): {{_base_.dataset_info}} is a placeholder that refers to a _base_
+# config, but the _base_ list at the top of this file is commented out. Confirm the
+# loader tolerates this; apply_keypose() itself passes dataset_info=None.
+data_root = 'data/coco'
+data = dict(
+ samples_per_gpu=32,
+ workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=32),
+ test_dataloader=dict(samples_per_gpu=32),
+ train=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
+ data_cfg=data_cfg,
+ pipeline=train_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+ val=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=val_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+ test=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=test_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+)
diff --git a/sd-webui-controlnet/annotator/lama/__init__.py b/sd-webui-controlnet/annotator/lama/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7784a3837d8454fe8991d7f7a4341331d8b1f0d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/__init__.py
@@ -0,0 +1,58 @@
+# https://github.com/advimman/lama
+
+import yaml
+import torch
+from omegaconf import OmegaConf
+import numpy as np
+
+from einops import rearrange
+import os
+from modules import devices
+from annotator.annotator_path import models_path
+from annotator.lama.saicinpainting.training.trainers import load_checkpoint
+
+
+class LamaInpainting:  # LaMa inpainting annotator: repaints the region selected by channel 3 of the input image
+    model_dir = os.path.join(models_path, "lama")  # directory the checkpoint is looked up in and downloaded to
+
+    def __init__(self):  # no weights are loaded here; __call__ loads them on first use
+        self.model = None
+        self.device = devices.get_device_for("controlnet")
+
+    def load_model(self):  # fetch the checkpoint if it is missing, build the model from config.yaml, move it to self.device
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetLama.pth"
+        modelpath = os.path.join(self.model_dir, "ControlNetLama.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=self.model_dir)
+        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.yaml')
+        with open(config_path, 'rt') as config_file:  # context manager closes the handle that a bare open() left dangling
+            cfg = OmegaConf.create(yaml.safe_load(config_file))
+        cfg.training_model.predict_only = True
+        cfg.visualizer.kind = 'noop'
+        self.model = load_checkpoint(cfg, os.path.abspath(modelpath), strict=False, map_location='cpu')
+        self.model = self.model.to(self.device)
+        self.model.eval()
+
+    def unload_model(self):  # moves the weights to the CPU; the model object itself is kept
+        if self.model is not None:
+            self.model.cpu()
+
+    def __call__(self, input_image):  # input_image: H x W x (>=4) array, channels 0-2 colour, channel 3 mask (presumably uint8 0-255)
+        if self.model is None:
+            self.load_model()
+        self.model.to(self.device)  # undo a previous unload_model()
+        color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
+        mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
+        with torch.no_grad():
+            color = torch.from_numpy(color).float().to(self.device)
+            mask = torch.from_numpy(mask).float().to(self.device)
+            mask = (mask > 0.5).float()  # binarise: 1 marks pixels to repaint
+            color = color * (1 - mask)  # blank out the masked pixels before feeding the network
+            image_feed = torch.cat([color, mask], dim=2)
+            image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
+            result = self.model(image_feed)[0]
+            result = rearrange(result, 'c h w -> h w c')
+            result = result * mask + color * (1 - mask)  # keep the known pixels, take the prediction only inside the mask
+            result *= 255.0
+            return result.detach().cpu().numpy().clip(0, 255).astype(np.uint8)  # uint8 array in h w c layout
diff --git a/sd-webui-controlnet/annotator/lama/config.yaml b/sd-webui-controlnet/annotator/lama/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..55fd91b5bcacd654e3045a2331e9c186818e6edc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/config.yaml
@@ -0,0 +1,157 @@
+run_title: b18_ffc075_batch8x15
+training_model:
+ kind: default
+ visualize_each_iters: 1000
+ concat_mask: true
+ store_discr_outputs_for_vis: true
+losses:
+ l1:
+ weight_missing: 0
+ weight_known: 10
+ perceptual:
+ weight: 0
+ adversarial:
+ kind: r1
+ weight: 10
+ gp_coef: 0.001
+ mask_as_fake_target: true
+ allow_scale_mask: true
+ feature_matching:
+ weight: 100
+ resnet_pl:
+ weight: 30
+ weights_path: ${env:TORCH_HOME}
+
+optimizers:
+ generator:
+ kind: adam
+ lr: 0.001
+ discriminator:
+ kind: adam
+ lr: 0.0001
+visualizer:
+ key_order:
+ - image
+ - predicted_image
+ - discr_output_fake
+ - discr_output_real
+ - inpainted
+ rescale_keys:
+ - discr_output_fake
+ - discr_output_real
+ kind: directory
+ outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
+location:
+ data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
+ out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
+ tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
+data:
+ batch_size: 15
+ val_batch_size: 2
+ num_workers: 3
+ train:
+ indir: ${location.data_root_dir}/train
+ out_size: 256
+ mask_gen_kwargs:
+ irregular_proba: 1
+ irregular_kwargs:
+ max_angle: 4
+ max_len: 200
+ max_width: 100
+ max_times: 5
+ min_times: 1
+ box_proba: 1
+ box_kwargs:
+ margin: 10
+ bbox_min_size: 30
+ bbox_max_size: 150
+ max_times: 3
+ min_times: 1
+ segm_proba: 0
+ segm_kwargs:
+ confidence_threshold: 0.5
+ max_object_area: 0.5
+ min_mask_area: 0.07
+ downsample_levels: 6
+ num_variants_per_mask: 1
+ rigidness_mode: 1
+ max_foreground_coverage: 0.3
+ max_foreground_intersection: 0.7
+ max_mask_intersection: 0.1
+ max_hidden_area: 0.1
+ max_scale_change: 0.25
+ horizontal_flip: true
+ max_vertical_shift: 0.2
+ position_shuffle: true
+ transform_variant: distortions
+ dataloader_kwargs:
+ batch_size: ${data.batch_size}
+ shuffle: true
+ num_workers: ${data.num_workers}
+ val:
+ indir: ${location.data_root_dir}/val
+ img_suffix: .png
+ dataloader_kwargs:
+ batch_size: ${data.val_batch_size}
+ shuffle: false
+ num_workers: ${data.num_workers}
+ visual_test:
+ indir: ${location.data_root_dir}/korean_test
+ img_suffix: _input.png
+ pad_out_to_modulo: 32
+ dataloader_kwargs:
+ batch_size: 1
+ shuffle: false
+ num_workers: ${data.num_workers}
+generator:
+ kind: ffc_resnet
+ input_nc: 4
+ output_nc: 3
+ ngf: 64
+ n_downsampling: 3
+ n_blocks: 18
+ add_out_act: sigmoid
+ init_conv_kwargs:
+ ratio_gin: 0
+ ratio_gout: 0
+ enable_lfu: false
+ downsample_conv_kwargs:
+ ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
+ ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
+ enable_lfu: false
+ resnet_conv_kwargs:
+ ratio_gin: 0.75
+ ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
+ enable_lfu: false
+discriminator:
+ kind: pix2pixhd_nlayer
+ input_nc: 3
+ ndf: 64
+ n_layers: 4
+evaluator:
+ kind: default
+ inpainted_key: inpainted
+ integral_kind: ssim_fid100_f1
+trainer:
+ kwargs:
+ gpus: -1
+ accelerator: ddp
+ max_epochs: 200
+ gradient_clip_val: 1
+ log_gpu_memory: None
+ limit_train_batches: 25000
+ val_check_interval: ${trainer.kwargs.limit_train_batches}
+ log_every_n_steps: 1000
+ precision: 32
+ terminate_on_nan: false
+ check_val_every_n_epoch: 1
+ num_sanity_val_steps: 8
+ limit_val_batches: 1000
+ replace_sampler_ddp: false
+ checkpoint_kwargs:
+ verbose: true
+ save_top_k: 5
+ save_last: true
+ period: 1
+ monitor: val_ssim_fid100_f1_total_mean
+ mode: max
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/data/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/data/masks.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/data/masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..27cb9050fa67c40d7d8d492a7088a621ad1ba2ce
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/data/masks.py
@@ -0,0 +1,332 @@
+import math
+import random
+import hashlib
+import logging
+from enum import Enum
+
+import cv2
+import numpy as np
+
+# from annotator.lama.saicinpainting.evaluation.masks.mask import SegmentationMask
+from annotator.lama.saicinpainting.utils import LinearRamp
+
+LOGGER = logging.getLogger(__name__)
+
+
+class DrawMethod(Enum):  # drawing primitives that make_random_irregular_mask can stamp along a stroke
+ LINE = 'line'
+ CIRCLE = 'circle'
+ SQUARE = 'square'
+
+
+def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
+                               draw_method=DrawMethod.LINE):  # returns a float32 (1, height, width) mask of random strokes
+    draw_method = DrawMethod(draw_method)  # accepts a DrawMethod member or its string value
+
+    height, width = shape
+    mask = np.zeros((height, width), np.float32)
+    times = np.random.randint(min_times, max_times + 1)  # number of strokes, max_times inclusive
+    for i in range(times):
+        start_x = np.random.randint(width)
+        start_y = np.random.randint(height)
+        for j in range(1 + np.random.randint(5)):  # each stroke is a chain of 1 to 5 segments
+            angle = 0.01 + np.random.randint(max_angle)
+            if i % 2 == 0:
+                angle = 2 * 3.1415926 - angle  # even-numbered strokes use the mirrored angle
+            length = 10 + np.random.randint(max_len)
+            brush_w = 5 + np.random.randint(max_width)
+            end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
+            end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
+            if draw_method == DrawMethod.LINE:
+                cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
+            elif draw_method == DrawMethod.CIRCLE:
+                cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
+            elif draw_method == DrawMethod.SQUARE:
+                radius = brush_w // 2
+                mask[max(start_y - radius, 0):start_y + radius, max(start_x - radius, 0):start_x + radius] = 1  # clamp at 0: a negative start wraps around and selects nothing
+            start_x, start_y = end_x, end_y  # the next segment continues from this end point
+    return mask[None, ...]
+
+
+class RandomIrregularMaskGenerator:  # callable that draws stroke masks through make_random_irregular_mask
+ def __init__(self, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, ramp_kwargs=None,
+ draw_method=DrawMethod.LINE):
+ self.max_angle = max_angle
+ self.max_len = max_len
+ self.max_width = max_width
+ self.min_times = min_times
+ self.max_times = max_times
+ self.draw_method = draw_method
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None  # a ramp is only built when ramp_kwargs is given
+
+ def __call__(self, img, iter_i=None, raw_image=None):  # raw_image is accepted but not read in this method
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1  # 1 when there is no ramp or no iteration index
+ cur_max_len = int(max(1, self.max_len * coef))  # scaled limits are floored at 1
+ cur_max_width = int(max(1, self.max_width * coef))
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)  # NOTE(review): equals max_times + 1 when coef == 1 - confirm the extra stroke is intended
+ return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,  # img.shape[1:] is passed as the mask shape
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
+ draw_method=self.draw_method)
+
+
+def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):  # returns a float32 (1, height, width) mask of random filled boxes
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)  # cap the box size so a box plus both margins fits
+ times = np.random.randint(min_times, max_times + 1)  # number of boxes, max_times inclusive
+ for i in range(times):
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)  # NOTE(review): upper bound is exclusive and must exceed bbox_min_size after the cap - confirm for small images
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
+ start_x = np.random.randint(margin, width - margin - box_width + 1)  # top-left corner stays at least `margin` from the borders
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
+ return mask[None, ...]
+
+
+class RandomRectangleMaskGenerator:  # callable that draws box masks through make_random_rectangle_mask
+ def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None):
+ self.margin = margin
+ self.bbox_min_size = bbox_min_size
+ self.bbox_max_size = bbox_max_size
+ self.min_times = min_times
+ self.max_times = max_times
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None  # a ramp is only built when ramp_kwargs is given
+
+ def __call__(self, img, iter_i=None, raw_image=None):  # raw_image is accepted but not read in this method
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1  # 1 when there is no ramp or no iteration index
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)  # interpolates from bbox_min_size + 1 up to bbox_max_size + 1
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)  # interpolates from min_times up to max_times
+ return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,  # img.shape[1:] is passed as the mask shape
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
+ max_times=cur_max_times)
+
+
+class RandomSegmentationMaskGenerator:  # returns one randomly chosen mask out of those SegmentationMask proposes for the image
+    def __init__(self, **kwargs):  # kwargs are stored and forwarded to SegmentationMask on first use
+        self.impl = None  # will be instantiated in first call (effectively in subprocess)
+        self.kwargs = kwargs
+
+    def __call__(self, img, iter_i=None, raw_image=None):  # iter_i and raw_image are accepted but not read here
+        if self.impl is None:
+            from annotator.lama.saicinpainting.evaluation.masks.mask import SegmentationMask  # lazy import: the module-level one is commented out, so the bare name raised NameError
+            self.impl = SegmentationMask(**self.kwargs)
+        masks = self.impl.get_masks(np.transpose(img, (1, 2, 0)))  # the image is handed over channel-last
+        masks = [m for m in masks if len(np.unique(m)) > 1]  # drop masks that hold a single distinct value
+        return masks[np.random.randint(len(masks))]  # uniform index pick: np.random.choice only accepts 1-D input
+
+
+def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3):  # returns a float32 (1, height, width) grid mask of periodic rows and columns
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ step_x = np.random.randint(min_step, max_step + 1)  # column period, max_step inclusive
+ width_x = np.random.randint(min_width, min(step_x, max_width + 1))  # band thickness stays below the period
+ offset_x = np.random.randint(0, step_x)
+
+ step_y = np.random.randint(min_step, max_step + 1)  # row period, max_step inclusive
+ width_y = np.random.randint(min_width, min(step_y, max_width + 1))
+ offset_y = np.random.randint(0, step_y)
+
+ for dy in range(width_y):
+ mask[offset_y + dy::step_y] = 1  # fill every step_y-th row, starting at offset_y + dy
+ for dx in range(width_x):
+ mask[:, offset_x + dx::step_x] = 1  # fill every step_x-th column, starting at offset_x + dx
+ return mask[None, ...]
+
+
+class RandomSuperresMaskGenerator:  # callable wrapper around make_random_superres_mask
+    def __init__(self, **kwargs):  # kwargs are forwarded unchanged to make_random_superres_mask on every call
+        self.kwargs = kwargs
+
+    def __call__(self, img, iter_i=None, raw_image=None):  # raw_image added (unused): MixedMaskGenerator passes it by keyword, which raised TypeError
+        return make_random_superres_mask(img.shape[1:], **self.kwargs)  # img.shape[1:] is passed as the mask shape
+
+
+class DumbAreaMaskGenerator:  # one rectangular mask per image: randomly placed when training, centred otherwise
+    min_ratio = 0.1  # lower bound of the masked area fraction; each side spans sqrt(ratio) of its dimension
+    max_ratio = 0.35  # upper bound of the masked area fraction when training
+    default_ratio = 0.225  # area fraction of the centred mask used when not training
+
+    def __init__(self, is_training):
+        #Parameters:
+        #    is_training(bool): If true - random rectangular mask, if false - central square mask
+        self.is_training = is_training
+
+    def _random_vector(self, dimension):  # returns the (start, stop) extent of the mask along an axis of length `dimension`
+        if self.is_training:
+            lower_limit = math.sqrt(self.min_ratio)
+            upper_limit = math.sqrt(self.max_ratio)
+            mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension)
+            u = random.randint(0, dimension-mask_side-1)  # random start that keeps the whole side inside the axis
+            v = u+mask_side
+        else:
+            margin = (math.sqrt(self.default_ratio) / 2) * dimension  # half of the mask side
+            u = round(dimension/2 - margin)
+            v = round(dimension/2 + margin)
+        return u, v
+
+    def __call__(self, img, iter_i=None, raw_image=None):  # img is (c, height, width); returns a float32 (1, height, width) mask
+        c, height, width = img.shape
+        mask = np.zeros((height, width), np.float32)
+        x1, x2 = self._random_vector(width)
+        y1, y2 = self._random_vector(height)
+        mask[y1:y2, x1:x2] = 1  # rows are y and columns are x; indexing [x, y] clipped the mask on non-square images
+        return mask[None, ...]
+
+
+class OutpaintingMaskGenerator:  # masks border strips (left / top / right / bottom) of an image
+    def __init__(self, min_padding_percent:float=0.04, max_padding_percent:float=0.25, left_padding_prob:float=0.5, top_padding_prob:float=0.5,
+                 right_padding_prob:float=0.5, bottom_padding_prob:float=0.5, is_fixed_randomness:bool=False):
+        """
+        is_fixed_randomness - get identical paddings for the same image if args are the same
+        """
+        self.min_padding_percent = min_padding_percent
+        self.max_padding_percent = max_padding_percent
+        self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
+        self.is_fixed_randomness = is_fixed_randomness
+
+        assert self.min_padding_percent <= self.max_padding_percent
+        assert self.max_padding_percent > 0
+        assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
+        assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
+        assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
+        if len([x for x in self.probs if x > 0]) == 1:
+            LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
+
+    def apply_padding(self, mask, coord):  # coord = [(row_start, col_start), (row_end, col_end)] as fractions of the image size
+        mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
+             int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
+        return mask  # the same array, modified in place
+
+    def get_padding(self, size):  # returns a padding fraction drawn on a grid of 1/size steps
+        n1 = int(self.min_padding_percent*size)
+        n2 = int(self.max_padding_percent*size)
+        return (self.rnd.randint(n1, n2) if n2 > n1 else n1) / size  # randint needs low < high; min == max is allowed by __init__
+
+    @staticmethod
+    def _img2rs(img):  # derives a 32-bit RandomState seed from the image content
+        arr = np.ascontiguousarray(img.astype(np.uint8))
+        str_hash = hashlib.sha1(arr).hexdigest()
+        res = int(str_hash, 16) % (2**32)  # built-in hash() of a str is salted per process, so it was not reproducible
+        return res
+
+    def __call__(self, img, iter_i=None, raw_image=None):  # img is (c, h, w); returns a float32 (1, h, w) mask; iter_i is unused
+        c, self.img_h, self.img_w = img.shape  # per-call state is kept on self for apply_padding
+        mask = np.zeros((self.img_h, self.img_w), np.float32)
+        at_least_one_mask_applied = False
+
+        if self.is_fixed_randomness:
+            assert raw_image is not None, f"Cant calculate hash on raw_image=None"
+            rs = self._img2rs(raw_image)
+            self.rnd = np.random.RandomState(rs)
+        else:
+            self.rnd = np.random  # global NumPy random state
+
+        coords = [[  # left strip; same order as self.probs (left, top, right, bottom)
+                   (0,0),
+                   (1,self.get_padding(size=self.img_h))
+                  ],
+                  [  # top strip
+                   (0,0),
+                   (self.get_padding(size=self.img_w),1)
+                  ],
+                  [  # right strip
+                   (0,1-self.get_padding(size=self.img_h)),
+                   (1,1)
+                  ],
+                  [  # bottom strip
+                   (1-self.get_padding(size=self.img_w),0),
+                   (1,1)
+                  ]]
+
+        for pp, coord in zip(self.probs, coords):
+            if self.rnd.random() < pp:
+                at_least_one_mask_applied = True
+                mask = self.apply_padding(mask=mask, coord=coord)
+
+        if not at_least_one_mask_applied:  # guarantee at least one strip, picked proportionally to the probabilities
+            idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
+            mask = self.apply_padding(mask=mask, coord=coords[idx])
+        return mask[None, ...]
+
+
+class MixedMaskGenerator:  # on each call, samples one of the enabled mask generators and returns its mask
+ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
+ box_proba=1/3, box_kwargs=None,
+ segm_proba=1/3, segm_kwargs=None,
+ squares_proba=0, squares_kwargs=None,
+ superres_proba=0, superres_kwargs=None,
+ outpainting_proba=0, outpainting_kwargs=None,
+ invert_proba=0):
+ self.probas = []  # selection weights, kept parallel to self.gens
+ self.gens = []
+
+ if irregular_proba > 0:  # a kind is only registered when its weight is positive
+ self.probas.append(irregular_proba)
+ if irregular_kwargs is None:
+ irregular_kwargs = {}
+ else:
+ irregular_kwargs = dict(irregular_kwargs)  # copy so the caller's mapping is not modified
+ irregular_kwargs['draw_method'] = DrawMethod.LINE  # any caller-supplied draw_method is overridden
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
+
+ if box_proba > 0:
+ self.probas.append(box_proba)
+ if box_kwargs is None:
+ box_kwargs = {}
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
+
+ if segm_proba > 0:
+ self.probas.append(segm_proba)
+ if segm_kwargs is None:
+ segm_kwargs = {}
+ self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs))
+
+ if squares_proba > 0:
+ self.probas.append(squares_proba)
+ if squares_kwargs is None:
+ squares_kwargs = {}
+ else:
+ squares_kwargs = dict(squares_kwargs)  # copy so the caller's mapping is not modified
+ squares_kwargs['draw_method'] = DrawMethod.SQUARE  # same generator class as "irregular", square stamps
+ self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs))
+
+ if superres_proba > 0:
+ self.probas.append(superres_proba)
+ if superres_kwargs is None:
+ superres_kwargs = {}
+ self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs))
+
+ if outpainting_proba > 0:
+ self.probas.append(outpainting_proba)
+ if outpainting_kwargs is None:
+ outpainting_kwargs = {}
+ self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
+
+ self.probas = np.array(self.probas, dtype='float32')
+ self.probas /= self.probas.sum()  # normalise the weights; NOTE(review): the array is empty if every *_proba is <= 0 - confirm callers enable at least one kind
+ self.invert_proba = invert_proba
+
+ def __call__(self, img, iter_i=None, raw_image=None):  # all three arguments are forwarded to the chosen generator
+ kind = np.random.choice(len(self.probas), p=self.probas)  # index of the generator to use
+ gen = self.gens[kind]
+ result = gen(img, iter_i=iter_i, raw_image=raw_image)
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
+ result = 1 - result  # swap masked and unmasked regions
+ return result
+
+
+def get_mask_generator(kind, kwargs):  # factory: instantiates the generator class selected by `kind` with `kwargs` as keyword arguments
+ if kind is None:
+ kind = "mixed"  # default generator kind
+ if kwargs is None:
+ kwargs = {}
+
+ if kind == "mixed":
+ cl = MixedMaskGenerator
+ elif kind == "outpainting":
+ cl = OutpaintingMaskGenerator
+ elif kind == "dumb":
+ cl = DumbAreaMaskGenerator
+ else:
+ raise NotImplementedError(f"No such generator kind = {kind}")  # any other kind is rejected
+ return cl(**kwargs)
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/adversarial.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/adversarial.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6db2967ce5074d94ed3b4c51fc743ff2f7831b1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/adversarial.py
@@ -0,0 +1,177 @@
+from typing import Tuple, Dict, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BaseAdversarialLoss:  # interface for adversarial losses; subclasses implement the two *_loss methods
+    def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+                           generator: nn.Module, discriminator: nn.Module):
+        """
+        Prepare for generator step (no-op in the base class)
+        :param real_batch: Tensor, a batch of real samples
+        :param fake_batch: Tensor, a batch of samples produced by generator
+        :param generator: nn.Module, the generator network
+        :param discriminator: nn.Module, the discriminator network
+        :return: None
+        """
+
+    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+                               generator: nn.Module, discriminator: nn.Module):
+        """
+        Prepare for discriminator step (no-op in the base class)
+        :param real_batch: Tensor, a batch of real samples
+        :param fake_batch: Tensor, a batch of samples produced by generator
+        :param generator: nn.Module, the generator network
+        :param discriminator: nn.Module, the discriminator network
+        :return: None
+        """
+
+    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
+                       mask: Optional[torch.Tensor] = None) \
+            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Calculate generator loss; the base class always raises NotImplementedError
+        :param real_batch: Tensor, a batch of real samples
+        :param fake_batch: Tensor, a batch of samples produced by generator
+        :param discr_real_pred: Tensor, discriminator output for real_batch
+        :param discr_fake_pred: Tensor, discriminator output for fake_batch
+        :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
+        :return: total generator loss along with some values that might be interesting to log
+        """
+        raise NotImplementedError()  # NotImplemented is a non-callable constant; calling it raised TypeError
+
+    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
+                           mask: Optional[torch.Tensor] = None) \
+            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Calculate discriminator loss; the base class always raises NotImplementedError
+        :param real_batch: Tensor, a batch of real samples
+        :param fake_batch: Tensor, a batch of samples produced by generator
+        :param discr_real_pred: Tensor, discriminator output for real_batch
+        :param discr_fake_pred: Tensor, discriminator output for fake_batch
+        :param mask: Tensor, actual mask, which was at input of generator when making fake_batch
+        :return: total discriminator loss along with some values that might be interesting to log
+        """
+        raise NotImplementedError()  # NotImplemented is a non-callable constant; calling it raised TypeError
+
+    def interpolate_mask(self, mask, shape):  # resizes mask to `shape` (its last two dims) when they differ; returns the mask
+        assert mask is not None
+        assert self.allow_scale_mask or shape == mask.shape[-2:]  # reads attributes the base class does not set; subclasses must define them
+        if shape != mask.shape[-2:] and self.allow_scale_mask:
+            if self.mask_scale_mode == 'maxpool':
+                mask = F.adaptive_max_pool2d(mask, shape)
+            else:
+                mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)  # any other mode string goes straight to F.interpolate
+        return mask
+
+def make_r1_gp(discr_real_pred, real_batch):  # gradient penalty of the discriminator output w.r.t. the real batch
+ if torch.is_grad_enabled():
+ grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]  # create_graph=True keeps the penalty differentiable
+ grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()  # squared L2 norm per sample, averaged over dim 0
+ else:
+ grad_penalty = 0  # no penalty when autograd is disabled
+ real_batch.requires_grad = False  # side effect on the caller's tensor
+
+ return grad_penalty
+
+class NonSaturatingWithR1(BaseAdversarialLoss):  # softplus-based adversarial loss with a gradient penalty on real samples (make_r1_gp)
+ def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
+ mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
+ use_unmasked_for_gen=True, use_unmasked_for_discr=True):
+ self.gp_coef = gp_coef  # multiplier of the gradient penalty in discriminator_loss
+ self.weight = weight  # multiplier of the generator loss
+ # use for discr => use for gen;
+ # otherwise we teach only the discr to pay attention to very small difference
+ assert use_unmasked_for_gen or (not use_unmasked_for_discr)
+ # mask as target => use unmasked for discr:
+ # if we don't care about unmasked regions at all
+ # then it doesn't matter if the value of mask_as_fake_target is true or false
+ assert use_unmasked_for_discr or (not mask_as_fake_target)
+ self.use_unmasked_for_gen = use_unmasked_for_gen
+ self.use_unmasked_for_discr = use_unmasked_for_discr
+ self.mask_as_fake_target = mask_as_fake_target
+ self.allow_scale_mask = allow_scale_mask  # read by interpolate_mask
+ self.mask_scale_mode = mask_scale_mode  # read by interpolate_mask
+ self.extra_mask_weight_for_gen = extra_mask_weight_for_gen
+
+ def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+ discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
+ mask=None) \
+ -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ fake_loss = F.softplus(-discr_fake_pred)  # per-element loss computed from the discriminator output on fakes only
+ if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
+ not self.use_unmasked_for_gen: # == if masked region should be treated differently
+ mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])  # bring the mask to the discriminator output size
+ if not self.use_unmasked_for_gen:
+ fake_loss = fake_loss * mask  # zero the loss outside the mask
+ else:
+ pixel_weights = 1 + mask * self.extra_mask_weight_for_gen  # extra weight inside the mask
+ fake_loss = fake_loss * pixel_weights
+
+ return fake_loss.mean() * self.weight, dict()  # no extra metrics are reported
+
+ def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+ generator: nn.Module, discriminator: nn.Module):
+ real_batch.requires_grad = True  # make_r1_gp differentiates the discriminator output w.r.t. real_batch
+
+ def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+ discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
+ mask=None) \
+ -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+
+ real_loss = F.softplus(-discr_real_pred)
+ grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
+ fake_loss = F.softplus(discr_fake_pred)
+
+ if not self.use_unmasked_for_discr or self.mask_as_fake_target:
+ # == if masked region should be treated differently
+ mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
+ # use_unmasked_for_discr=False only makes sense for fakes;
+ # for reals there is no difference between two regions
+ fake_loss = fake_loss * mask
+ if self.mask_as_fake_target:
+ fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)  # outside the mask, fake pixels get the same loss form as reals
+
+ sum_discr_loss = real_loss + grad_penalty + fake_loss
+ metrics = dict(discr_real_out=discr_real_pred.mean(),
+ discr_fake_out=discr_fake_pred.mean(),
+ discr_real_gp=grad_penalty)
+ return sum_discr_loss.mean(), metrics  # self.weight is not applied to the discriminator loss
+
+class BCELoss(BaseAdversarialLoss):  # adversarial loss built on nn.BCEWithLogitsLoss
+ def __init__(self, weight):
+ self.weight = weight  # multiplier of the generator loss only
+ self.bce_loss = nn.BCEWithLogitsLoss()
+
+ def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # NOTE(review): signature differs from BaseAdversarialLoss.generator_loss - confirm callers
+ real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device)  # all-zero target with the shape of the prediction
+ fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
+ return fake_loss, dict()  # no extra metrics are reported
+
+ def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
+ generator: nn.Module, discriminator: nn.Module):
+ real_batch.requires_grad = True
+
+ def discriminator_loss(self,  # NOTE(review): parameter list differs from BaseAdversarialLoss.discriminator_loss - confirm callers
+ mask: torch.Tensor,
+ discr_real_pred: torch.Tensor,
+ discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+
+ real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device)  # real samples get an all-zero target
+ sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2  # fake samples use the mask itself as target
+ metrics = dict(discr_real_out=discr_real_pred.mean(),
+ discr_fake_out=discr_fake_pred.mean(),
+ discr_real_gp=0)  # no gradient penalty is computed in this loss
+ return sum_discr_loss, metrics
+
+
+def make_discrim_loss(kind, **kwargs):  # factory: builds the adversarial loss selected by `kind`, forwarding kwargs to its constructor
+ if kind == 'r1':
+ return NonSaturatingWithR1(**kwargs)
+ elif kind == 'bce':
+ return BCELoss(**kwargs)
+ raise ValueError(f'Unknown adversarial loss kind {kind}')  # any other kind is rejected
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/constants.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3e5e151342232be8e2c2a77fe6fd5798dc2a8c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/constants.py
@@ -0,0 +1,152 @@
+weights = {"ade20k":
+ [6.34517766497462,
+ 9.328358208955224,
+ 11.389521640091116,
+ 16.10305958132045,
+ 20.833333333333332,
+ 22.22222222222222,
+ 25.125628140703515,
+ 43.29004329004329,
+ 50.5050505050505,
+ 54.6448087431694,
+ 55.24861878453038,
+ 60.24096385542168,
+ 62.5,
+ 66.2251655629139,
+ 84.74576271186442,
+ 90.90909090909092,
+ 91.74311926605505,
+ 96.15384615384616,
+ 96.15384615384616,
+ 97.08737864077669,
+ 102.04081632653062,
+ 135.13513513513513,
+ 149.2537313432836,
+ 153.84615384615384,
+ 163.93442622950818,
+ 166.66666666666666,
+ 188.67924528301887,
+ 192.30769230769232,
+ 217.3913043478261,
+ 227.27272727272725,
+ 227.27272727272725,
+ 227.27272727272725,
+ 303.03030303030306,
+ 322.5806451612903,
+ 333.3333333333333,
+ 370.3703703703703,
+ 384.61538461538464,
+ 416.6666666666667,
+ 416.6666666666667,
+ 434.7826086956522,
+ 434.7826086956522,
+ 454.5454545454545,
+ 454.5454545454545,
+ 500.0,
+ 526.3157894736842,
+ 526.3157894736842,
+ 555.5555555555555,
+ 555.5555555555555,
+ 555.5555555555555,
+ 555.5555555555555,
+ 555.5555555555555,
+ 555.5555555555555,
+ 555.5555555555555,
+ 588.2352941176471,
+ 588.2352941176471,
+ 588.2352941176471,
+ 588.2352941176471,
+ 588.2352941176471,
+ 666.6666666666666,
+ 666.6666666666666,
+ 666.6666666666666,
+ 666.6666666666666,
+ 714.2857142857143,
+ 714.2857142857143,
+ 714.2857142857143,
+ 714.2857142857143,
+ 714.2857142857143,
+ 769.2307692307693,
+ 769.2307692307693,
+ 769.2307692307693,
+ 833.3333333333334,
+ 833.3333333333334,
+ 833.3333333333334,
+ 833.3333333333334,
+ 909.090909090909,
+ 1000.0,
+ 1111.111111111111,
+ 1111.111111111111,
+ 1111.111111111111,
+ 1111.111111111111,
+ 1111.111111111111,
+ 1250.0,
+ 1250.0,
+ 1250.0,
+ 1250.0,
+ 1250.0,
+ 1428.5714285714287,
+ 1428.5714285714287,
+ 1428.5714285714287,
+ 1428.5714285714287,
+ 1428.5714285714287,
+ 1428.5714285714287,
+ 1428.5714285714287,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 1666.6666666666667,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2000.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 2500.0,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 3333.3333333333335,
+ 5000.0,
+ 5000.0,
+ 5000.0]
+}
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/distance_weighting.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/distance_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ce05bee5f633662057b3347d8791e1b4d115a0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/distance_weighting.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+from annotator.lama.saicinpainting.training.losses.perceptual import IMAGENET_STD, IMAGENET_MEAN
+
+
+def dummy_distance_weighter(real_img, pred_img, mask):  # no-op weighter: both image arguments are ignored
+    return mask  # the mask itself is returned as the weight map
+
+
+def get_gauss_kernel(kernel_size, width_factor=1):
+    """Return a (kernel_size, kernel_size) float kernel exp(-d^2 / kernel_size / width_factor), scaled to sum to 1."""
+    axis = torch.arange(kernel_size)
+    grid = torch.stack(torch.meshgrid(axis, axis), dim=0).float()
+    sq_dist = ((grid - kernel_size // 2) ** 2).sum(0)  # squared distance from index kernel_size // 2
+    kernel = torch.exp(-sq_dist / kernel_size / width_factor)
+    return kernel / kernel.sum()
+
+
+class BlurMask(nn.Module):
+    """Weight map = (mask filtered with a fixed normalised kernel) * mask; the image arguments are ignored."""
+    def __init__(self, kernel_size=5, width_factor=1):
+        super().__init__()
+        self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode='replicate', bias=False)
+        self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor))
+
+    def forward(self, real_img, pred_img, mask):
+        with torch.no_grad():  # the weight map is computed without tracking gradients
+            return self.filter(mask) * mask
+
+
+class EmulatedEDTMask(nn.Module):
+    """Weight map built by box-filtering the known region (1 - mask), thresholding, blurring and masking."""
+    def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1):
+        super().__init__()
+        self.dilate_filter = nn.Conv2d(1, 1, dilate_kernel_size, padding=dilate_kernel_size // 2,
+                                       padding_mode='replicate', bias=False)
+        self.dilate_filter.weight.data.copy_(torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float))
+        self.blur_filter = nn.Conv2d(1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode='replicate', bias=False)
+        self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor))
+
+    def forward(self, real_img, pred_img, mask):
+        with torch.no_grad():
+            # a pixel counts as "near known" when the all-ones filter response over (1 - mask) exceeds 1
+            near_known = (self.dilate_filter(1 - mask) > 1).float()
+            return self.blur_filter(1 - near_known) * mask
+
+
+class PropagatePerceptualSim(nn.Module):
+    """Mask weighter that spreads "knownness" from known pixels to neighbours with similar VGG19 features."""
+    def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3):
+        super().__init__()
+        vgg = torchvision.models.vgg19(pretrained=True).features
+        vgg_avg_pooling = []
+        for weights in vgg.parameters():
+            weights.requires_grad = False
+        # keep VGG layers up to and including the `level`-th ReLU, swapping max pooling for average pooling
+        cur_level_i = 0
+        for module in vgg.modules():
+            if module.__class__.__name__ == 'Sequential':
+                continue
+            elif module.__class__.__name__ == 'MaxPool2d':
+                vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
+            else:
+                vgg_avg_pooling.append(module)
+                if module.__class__.__name__ == 'ReLU':
+                    cur_level_i += 1
+                    if cur_level_i == level:
+                        break
+
+        self.features = nn.Sequential(*vgg_avg_pooling)
+
+        self.max_iters = max_iters
+        self.temperature = temperature
+        self.do_erode = erode_mask_size > 0
+        if self.do_erode:
+            self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False)
+            self.erode_mask.weight.data.fill_(1)
+
+    def forward(self, real_img, pred_img, mask):
+        with torch.no_grad():
+            real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img)
+            real_feats = self.features(real_img)
+            # exp(-squared feature difference / temperature) between vertically / horizontally adjacent cells
+            vertical_sim = torch.exp(-(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True)
+                                     / self.temperature)
+            horizontal_sim = torch.exp(-(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True)
+                                       / self.temperature)
+
+            mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode='bilinear', align_corners=False)
+            if self.do_erode:
+                mask_scaled = (self.erode_mask(mask_scaled) > 1).float()
+
+            cur_knowness = 1 - mask_scaled
+            # each iteration lets knownness flow one cell in each of the four directions, scaled by similarity
+            for _ in range(self.max_iters):
+                new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode='replicate')
+                new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode='replicate')
+
+                new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode='replicate')
+                new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 0, 0), mode='replicate')
+
+                new_knowness = torch.stack([new_top_knowness, new_bottom_knowness,
+                                            new_left_knowness, new_right_knowness],
+                                           dim=0).max(0).values
+
+                cur_knowness = torch.max(cur_knowness, new_knowness)
+            # align_corners is spelled out to match the down-scaling call above (same sampling as the default)
+            cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode='bilinear', align_corners=False)
+            result = torch.min(mask, 1 - cur_knowness)
+
+        return result
+
+
+def make_mask_distance_weighter(kind='none', **kwargs):
+    """Build the mask weighter named by `kind` ('none', 'blur', 'edt' or 'pps'); kwargs go to its constructor."""
+    if kind == 'none':
+        return dummy_distance_weighter  # plain function; kwargs are not used for this kind
+    for name, weighter_cls in (('blur', BlurMask),
+                               ('edt', EmulatedEDTMask),
+                               ('pps', PropagatePerceptualSim)):
+        if kind == name:
+            return weighter_cls(**kwargs)
+    raise ValueError(f'Unknown mask distance weighter kind {kind}')
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/feature_matching.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/feature_matching.py
new file mode 100644
index 0000000000000000000000000000000000000000..c019895c9178817837d1a6773367b178a861dc61
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/feature_matching.py
@@ -0,0 +1,33 @@
+from typing import List
+
+import torch
+import torch.nn.functional as F
+
+
+def masked_l2_loss(pred, target, mask, weight_known, weight_missing):
+    """Mean of squared errors weighted per pixel by mask * weight_missing + (1 - mask) * weight_known."""
+    weights = weight_missing * mask + weight_known * (1 - mask)
+    return (weights * F.mse_loss(pred, target, reduction='none')).mean()
+
+
+def masked_l1_loss(pred, target, mask, weight_known, weight_missing):
+    """Mean of absolute errors weighted per pixel by mask * weight_missing + (1 - mask) * weight_known."""
+    weights = weight_missing * mask + weight_known * (1 - mask)
+    return (weights * F.l1_loss(pred, target, reduction='none')).mean()
+
+
+def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None):
+    """Average per-layer MSE between two feature lists; with `mask`, errors are weighted by (1 - resized mask)."""
+    if mask is None:
+        layer_losses = [F.mse_loss(fake_feat, target_feat)
+                        for fake_feat, target_feat in zip(fake_features, target_features)]
+        return torch.stack(layer_losses).mean()
+
+    total, n_layers = 0, 0
+    for fake_feat, target_feat in zip(fake_features, target_features):
+        # bring the mask to this layer's spatial resolution
+        layer_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False)
+        sq_err = (fake_feat - target_feat).pow(2)
+        total = total + (sq_err * (1 - layer_mask)).mean()
+        n_layers += 1
+    return total / n_layers
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/perceptual.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8b0b309b2b8ba95172cb16af440033a4aeafae
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/perceptual.py
@@ -0,0 +1,113 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+# from models.ade20k import ModelBuilder
+from annotator.lama.saicinpainting.utils import check_and_warn_input_range
+
+
+IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None]  # three per-channel means, shape (1, 3, 1, 1)
+IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None]  # three per-channel stds, shape (1, 3, 1, 1)
+
+
+class PerceptualLoss(nn.Module):
+    """VGG19-based perceptual loss: MSE between activations after each ReLU among the first 30 layers."""
+
+    def __init__(self, normalize_inputs=True):
+        super().__init__()
+        self.normalize_inputs = normalize_inputs
+        self.mean_ = IMAGENET_MEAN
+        self.std_ = IMAGENET_STD
+
+        backbone = torchvision.models.vgg19(pretrained=True).features
+        for param in backbone.parameters():
+            param.requires_grad = False
+
+        # rebuild the feature stack with average pooling in place of max pooling
+        layers = []
+        for module in backbone.modules():
+            kind = module.__class__.__name__
+            if kind == 'Sequential':
+                continue
+            if kind == 'MaxPool2d':
+                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
+            else:
+                layers.append(module)
+        self.vgg = nn.Sequential(*layers)
+
+    def do_normalize_inputs(self, x):
+        mean, std = self.mean_.to(x.device), self.std_.to(x.device)
+        return (x - mean) / std
+
+    def partial_losses(self, input, target, mask=None):
+        """Return a list with one loss tensor per ReLU layer; only `target` is range-checked here."""
+        # we expect input and target to be in [0, 1] range
+        check_and_warn_input_range(target, 0, 1, 'PerceptualLoss target in partial_losses')
+
+        feats_in, feats_tgt = input, target
+        if self.normalize_inputs:
+            feats_in = self.do_normalize_inputs(input)
+            feats_tgt = self.do_normalize_inputs(target)
+
+        per_layer = []
+        for layer in self.vgg[:30]:
+            feats_in = layer(feats_in)
+            feats_tgt = layer(feats_tgt)
+            if layer.__class__.__name__ != 'ReLU':
+                continue
+
+            layer_loss = F.mse_loss(feats_in, feats_tgt, reduction='none')
+            if mask is not None:
+                # positions where the resized mask is 1 get weight 0
+                layer_mask = F.interpolate(mask, size=feats_in.shape[-2:],
+                                           mode='bilinear', align_corners=False)
+                layer_loss = layer_loss * (1 - layer_mask)
+
+            # average over every dimension except the first
+            per_layer.append(layer_loss.mean(dim=tuple(range(1, layer_loss.dim()))))
+
+        return per_layer
+
+    def forward(self, input, target, mask=None):
+        """Sum the per-layer losses from `partial_losses` over layers."""
+        return torch.stack(self.partial_losses(input, target, mask=mask)).sum(dim=0)
+
+    def get_global_features(self, input):
+        """Run the whole VGG stack on `input`, normalising it first when `normalize_inputs` is set."""
+        check_and_warn_input_range(input, 0, 1, 'PerceptualLoss input in get_global_features')
+
+        if self.normalize_inputs:
+            features = self.do_normalize_inputs(input)
+        else:
+            features = input
+
+        return self.vgg(features)
+
+
+class ResNetPL(nn.Module):  # weighted MSE between feature maps of a frozen encoder from ModelBuilder.get_encoder
+    def __init__(self, weight=1,
+                 weights_path=None, arch_encoder='resnet50dilated', segmentation=True):
+        super().__init__()
+        self.impl = ModelBuilder.get_encoder(weights_path=weights_path,  # NOTE(review): ModelBuilder's import is commented out at the top of this file, so this raises NameError as written
+                                             arch_encoder=arch_encoder,
+                                             arch_decoder='ppm_deepsup',
+                                             fc_dim=2048,
+                                             segmentation=segmentation)
+        self.impl.eval()  # evaluation mode; parameters are frozen just below
+        for w in self.impl.parameters():
+            w.requires_grad_(False)
+
+        self.weight = weight  # scalar multiplier applied to the summed loss
+
+    def forward(self, pred, target):
+        pred = (pred - IMAGENET_MEAN.to(pred)) / IMAGENET_STD.to(pred)  # same normalisation for both inputs
+        target = (target - IMAGENET_MEAN.to(target)) / IMAGENET_STD.to(target)
+
+        pred_feats = self.impl(pred, return_feature_maps=True)  # presumably a list of feature maps -- confirm against ModelBuilder
+        target_feats = self.impl(target, return_feature_maps=True)
+
+        result = torch.stack([F.mse_loss(cur_pred, cur_target)  # one MSE per pair of feature maps, summed then scaled
+                              for cur_pred, cur_target
+                              in zip(pred_feats, target_feats)]).sum() * self.weight
+        return result
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/segmentation.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d4a9f94eaae84722db584277dbbf9bc41ede357
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/segmentation.py
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .constants import weights as constant_weights
+
+
+class CrossEntropy2d(nn.Module):
+    def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs):
+        """
+        weights (optional): key into the `weights` table imported from `.constants`; the per-class
+        rescaling weights stored under that key are used. None (default) disables class weighting.
+        """
+        super(CrossEntropy2d, self).__init__()
+        self.reduction = reduction
+        self.ignore_label = ignore_label
+        self.weights = weights
+        if self.weights is not None:
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
+
+    def forward(self, predict, target):
+        """
+        Args:
+            predict:(n, c, h, w)
+            target:(n, 1, h, w)
+        """
+        target = target.long()
+        assert not target.requires_grad
+        assert predict.dim() == 4, "{0}".format(predict.size())
+        assert target.dim() == 4, "{0}".format(target.size())
+        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
+        assert target.size(1) == 1, "{0}".format(target.size(1))
+        assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
+        assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
+        target = target.squeeze(1)
+        n, c, h, w = predict.size()
+        target_mask = (target >= 0) * (target != self.ignore_label)  # keep labels that are >= 0 and not ignore_label
+        target = target[target_mask]
+        predict = predict.transpose(1, 2).transpose(2, 3).contiguous()  # (n, h, w, c)
+        predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
+        class_weights = None if self.weights is None else self.weights.to(predict.device)  # device was fixed at construction; follow the logits
+        return F.cross_entropy(predict, target, weight=class_weights, reduction=self.reduction)
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/style_loss.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/style_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bb42d7fbc5d17a47bec7365889868505f5fdfb5
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/losses/style_loss.py
@@ -0,0 +1,155 @@
+import torch
+import torch.nn as nn
+import torchvision.models as models
+
+
+class PerceptualLoss(nn.Module):
+    r"""
+    Perceptual loss, VGG-based
+    https://arxiv.org/abs/1603.08155
+    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
+
+    Weighted sum of L1 distances between the VGG19 activations of `x` and `y` at relu1_1 .. relu5_1.
+    """
+
+    def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
+        super(PerceptualLoss, self).__init__()
+        self.add_module('vgg', VGG19())
+        self.criterion = torch.nn.L1Loss()
+        # copy, so an instance never aliases the shared default list (or the caller's list)
+        self.weights = list(weights)
+
+    # NOTE: __call__ is overridden here rather than forward
+    def __call__(self, x, y):
+        # Compute features
+        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+
+        content_loss = 0.0
+        layers = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
+        for i, layer in enumerate(layers):  # indexing keeps the IndexError for weight lists shorter than 5
+            content_loss += self.weights[i] * self.criterion(x_vgg[layer], y_vgg[layer])
+        return content_loss
+
+
+class VGG19(torch.nn.Module):  # pretrained torchvision VGG19 `features`, split into 16 named stages
+    def __init__(self):
+        super(VGG19, self).__init__()
+        features = models.vgg19(pretrained=True).features
+        self.relu1_1 = torch.nn.Sequential()
+        self.relu1_2 = torch.nn.Sequential()
+
+        self.relu2_1 = torch.nn.Sequential()
+        self.relu2_2 = torch.nn.Sequential()
+
+        self.relu3_1 = torch.nn.Sequential()
+        self.relu3_2 = torch.nn.Sequential()
+        self.relu3_3 = torch.nn.Sequential()
+        self.relu3_4 = torch.nn.Sequential()
+
+        self.relu4_1 = torch.nn.Sequential()
+        self.relu4_2 = torch.nn.Sequential()
+        self.relu4_3 = torch.nn.Sequential()
+        self.relu4_4 = torch.nn.Sequential()
+
+        self.relu5_1 = torch.nn.Sequential()
+        self.relu5_2 = torch.nn.Sequential()
+        self.relu5_3 = torch.nn.Sequential()
+        self.relu5_4 = torch.nn.Sequential()
+        # distribute features[0:36] over the stages by index; every layer is registered under its original index
+        for x in range(2):
+            self.relu1_1.add_module(str(x), features[x])
+
+        for x in range(2, 4):
+            self.relu1_2.add_module(str(x), features[x])
+
+        for x in range(4, 7):
+            self.relu2_1.add_module(str(x), features[x])
+
+        for x in range(7, 9):
+            self.relu2_2.add_module(str(x), features[x])
+
+        for x in range(9, 12):
+            self.relu3_1.add_module(str(x), features[x])
+
+        for x in range(12, 14):
+            self.relu3_2.add_module(str(x), features[x])
+        # layers 14-15 form relu3_3; they were previously appended to relu3_2, which left relu3_3 empty
+        for x in range(14, 16):
+            self.relu3_3.add_module(str(x), features[x])  # parameter keys for these layers are now relu3_3.*
+
+        for x in range(16, 18):
+            self.relu3_4.add_module(str(x), features[x])
+
+        for x in range(18, 21):
+            self.relu4_1.add_module(str(x), features[x])
+
+        for x in range(21, 23):
+            self.relu4_2.add_module(str(x), features[x])
+
+        for x in range(23, 25):
+            self.relu4_3.add_module(str(x), features[x])
+
+        for x in range(25, 27):
+            self.relu4_4.add_module(str(x), features[x])
+
+        for x in range(27, 30):
+            self.relu5_1.add_module(str(x), features[x])
+
+        for x in range(30, 32):
+            self.relu5_2.add_module(str(x), features[x])
+
+        for x in range(32, 34):
+            self.relu5_3.add_module(str(x), features[x])
+
+        for x in range(34, 36):
+            self.relu5_4.add_module(str(x), features[x])
+
+        # don't need the gradients, just want the features
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, x):
+        relu1_1 = self.relu1_1(x)
+        relu1_2 = self.relu1_2(relu1_1)
+
+        relu2_1 = self.relu2_1(relu1_2)
+        relu2_2 = self.relu2_2(relu2_1)
+
+        relu3_1 = self.relu3_1(relu2_2)
+        relu3_2 = self.relu3_2(relu3_1)
+        relu3_3 = self.relu3_3(relu3_2)
+        relu3_4 = self.relu3_4(relu3_3)
+
+        relu4_1 = self.relu4_1(relu3_4)
+        relu4_2 = self.relu4_2(relu4_1)
+        relu4_3 = self.relu4_3(relu4_2)
+        relu4_4 = self.relu4_4(relu4_3)
+
+        relu5_1 = self.relu5_1(relu4_4)
+        relu5_2 = self.relu5_2(relu5_1)
+        relu5_3 = self.relu5_3(relu5_2)
+        relu5_4 = self.relu5_4(relu5_3)
+        # every stage output is returned, keyed by stage name
+        out = {
+            'relu1_1': relu1_1,
+            'relu1_2': relu1_2,
+
+            'relu2_1': relu2_1,
+            'relu2_2': relu2_2,
+
+            'relu3_1': relu3_1,
+            'relu3_2': relu3_2,
+            'relu3_3': relu3_3,
+            'relu3_4': relu3_4,
+
+            'relu4_1': relu4_1,
+            'relu4_2': relu4_2,
+            'relu4_3': relu4_3,
+            'relu4_4': relu4_4,
+
+            'relu5_1': relu5_1,
+            'relu5_2': relu5_2,
+            'relu5_3': relu5_3,
+            'relu5_4': relu5_4,
+        }
+        return out
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c56ad9965ec95f3ae28c35c2ab42456eb06066
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/__init__.py
@@ -0,0 +1,31 @@
+import logging
+
+from annotator.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator
+from annotator.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \
+ NLayerDiscriminator, MultidilatedNLayerDiscriminator
+
+def make_generator(config, kind, **kwargs):
+    """Instantiate the generator class named by `kind` with `kwargs`; `config` is accepted but not used here."""
+    logging.info(f'Make generator {kind}')
+
+    generator_classes = (
+        ('pix2pixhd_multidilated', MultiDilatedGlobalGenerator),
+        ('pix2pixhd_global', GlobalGenerator),
+        ('ffc_resnet', FFCResNetGenerator),
+    )
+    for name, generator_cls in generator_classes:
+        if kind == name:
+            return generator_cls(**kwargs)
+    raise ValueError(f'Unknown generator kind {kind}')
+
+
+def make_discriminator(kind, **kwargs):
+    """Instantiate the discriminator class named by `kind`, forwarding `kwargs` to its constructor."""
+    logging.info(f'Make discriminator {kind}')
+
+    for name, discriminator_cls in (('pix2pixhd_nlayer_multidilated', MultidilatedNLayerDiscriminator),
+                                    ('pix2pixhd_nlayer', NLayerDiscriminator)):
+        if kind == name:
+            return discriminator_cls(**kwargs)
+
+    raise ValueError(f'Unknown discriminator kind {kind}')
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/base.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..58c513987601d6a442ca8f066f82f1af46e28939
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/base.py
@@ -0,0 +1,80 @@
+import abc
+from typing import Tuple, List
+
+import torch
+import torch.nn as nn
+
+from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
+from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
+
+
+class BaseDiscriminator(nn.Module):  # interface for discriminators; forward must be overridden
+    @abc.abstractmethod
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Predict scores and get intermediate activations. Useful for feature matching loss
+        :return tuple (scores, list of intermediate activations)
+        """
+        raise NotImplementedError()  # NotImplemented is a non-callable constant; this is the exception class
+
+
+def get_conv_block_ctor(kind='default'):
+    """Map a conv-block name to its constructor; any non-string `kind` is returned untouched."""
+    if not isinstance(kind, str):
+        return kind
+    ctors = {'default': nn.Conv2d,
+             'depthwise': DepthWiseSeperableConv,
+             'multidilated': MultidilatedConv}
+    if kind in ctors:
+        return ctors[kind]
+    raise ValueError(f'Unknown convolutional block kind {kind}')
+
+
+def get_norm_layer(kind='bn'):
+    """Map 'bn' / 'in' to nn.BatchNorm2d / nn.InstanceNorm2d; any non-string `kind` is returned untouched."""
+    if not isinstance(kind, str):
+        return kind
+    norm_classes = {'bn': nn.BatchNorm2d, 'in': nn.InstanceNorm2d}
+    if kind in norm_classes:
+        return norm_classes[kind]
+    raise ValueError(f'Unknown norm block kind {kind}')
+
+
+def get_activation(kind='tanh'):
+    """Return a new nn.Tanh / nn.Sigmoid for 'tanh' / 'sigmoid', or nn.Identity when `kind` is exactly False."""
+    for name, activation_cls in (('tanh', nn.Tanh), ('sigmoid', nn.Sigmoid)):
+        if kind == name:
+            return activation_cls()
+    if kind is False:
+        return nn.Identity()
+    raise ValueError(f'Unknown activation kind {kind}')
+
+
+class SimpleMultiStepGenerator(nn.Module):
+    """Runs `steps` in order; each step receives the input concatenated (dim=1) with all earlier step outputs."""
+    def __init__(self, steps: List[nn.Module]):
+        super().__init__()
+        self.steps = nn.ModuleList(steps)
+
+    def forward(self, x):
+        stage_input, stage_outputs = x, []
+        for stage in self.steps:
+            stage_outputs.append(stage(stage_input))
+            stage_input = torch.cat((stage_input, stage_outputs[-1]), dim=1)
+        # the outputs are concatenated along dim=1 with the latest step first
+        return torch.cat(list(reversed(stage_outputs)), dim=1)
+
+def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features):
+    """Return the layer list [upsampling layer(s), norm, activation] for one decoder stage of type `kind`."""
+    if kind not in ('convtranspose', 'bilinear'):
+        raise Exception(f"Invalid deconv kind: {kind}")
+    # both channel counts are capped at max_features; the output count is int(ngf * mult / 2)
+    in_ch = min(max_features, ngf * mult)
+    out_ch = min(max_features, int(ngf * mult / 2))
+    if kind == 'convtranspose':
+        upsample = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)]
+    else:
+        upsample = [nn.Upsample(scale_factor=2, mode='bilinear'),
+                    DepthWiseSeperableConv(in_ch, out_ch, kernel_size=3, stride=1, padding=1)]
+    # the norm layer is constructed after the upsampling layers, matching the list order
+    return upsample + [norm_layer(out_ch), activation]
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..83dd15c3df1d9f40baf0091a373fa224532c9ddd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/depthwise_sep_conv.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+class DepthWiseSeperableConv(nn.Module):
+    """Grouped convolution with groups=in_dim followed by a 1x1 convolution mapping in_dim to out_dim."""
+
+    def __init__(self, in_dim, out_dim, *args, **kwargs):
+        super().__init__()
+        # a caller-supplied `groups` is discarded: the first convolution always uses groups=in_dim
+        kwargs.pop('groups', None)
+
+        self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs)
+        self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+    def forward(self, x):
+        per_channel = self.depthwise(x)
+        return self.pointwise(per_channel)
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/fake_fakes.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/fake_fakes.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c4ad559cef2730b771a709197e00ae1c87683c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/fake_fakes.py
@@ -0,0 +1,47 @@
+import torch
+from kornia import SamplePadding
+from kornia.augmentation import RandomAffine, CenterCrop
+
+
+class FakeFakesGenerator:  # blends each input image with a "blend target" through a soft mask derived from `masks`
+    def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2):
+        self.grad_aug = RandomAffine(degrees=360,  # random affine applied to the gradient ramp, always on (p=1)
+                                     translate=0.2,
+                                     padding_mode=SamplePadding.REFLECTION,
+                                     keepdim=False,
+                                     p=1)
+        self.img_aug = RandomAffine(degrees=img_aug_degree,  # random affine applied to the images, always on (p=1)
+                                    translate=img_aug_translate,
+                                    padding_mode=SamplePadding.REFLECTION,
+                                    keepdim=True,
+                                    p=1)
+        self.aug_proba = aug_proba  # per-sample probability of using the augmented image as blend target
+
+    def __call__(self, input_images, masks):
+        blend_masks = self._fill_masks_with_gradient(masks)  # soft blending weights, clamped to [0, 1]
+        blend_target = self._make_blend_target(input_images)
+        result = input_images * (1 - blend_masks) + blend_target * blend_masks  # per-pixel linear blend
+        return result, blend_masks
+
+    def _make_blend_target(self, input_images):
+        batch_size = input_images.shape[0]
+        permuted = input_images[torch.randperm(batch_size)]  # the batch shuffled along dim 0
+        augmented = self.img_aug(input_images)
+        is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float()  # 1.0 -> augmented image, 0.0 -> shuffled image
+        result = augmented * is_aug + permuted * (1 - is_aug)
+        return result
+
+    def _fill_masks_with_gradient(self, masks):
+        batch_size, _, height, width = masks.shape  # masks must be 4-D
+        grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \
+            .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2)
+        grad = self.grad_aug(grad)  # ramp built at twice the mask size, randomly transformed, then centre-cropped
+        grad = CenterCrop((height, width))(grad)
+        grad *= masks  # zero wherever the mask is zero
+        # adding 10 where the mask is 0 presumably keeps those pixels out of the per-sample minimum -- confirm masks are binary
+        grad_for_min = grad + (1 - masks) * 10
+        grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None]  # subtract that per-sample minimum
+        grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6  # divide by per-sample max; 1e-6 avoids division by zero
+        grad.clamp_(min=0, max=1)
+
+        return grad
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/ffc.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/ffc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e67ff9c832463e5518d6ccea2c6f27531ed778d4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/ffc.py
@@ -0,0 +1,485 @@
+# Fast Fourier Convolution NeurIPS 2020
+# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
+# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.lama.saicinpainting.training.modules.base import get_activation, BaseDiscriminator
+from annotator.lama.saicinpainting.training.modules.spatial_transform import LearnableSpatialTransformWrapper
+from annotator.lama.saicinpainting.training.modules.squeeze_excitation import SELayer
+from annotator.lama.saicinpainting.utils import get_shape
+
+
+class FFCSE_block(nn.Module):
+    """Squeeze-and-excitation style gate producing separate channel scalings for the local and global parts."""
+    def __init__(self, channels, ratio_g):
+        super(FFCSE_block, self).__init__()
+        in_cg = int(channels * ratio_g)
+        in_cl = channels - in_cg
+        hidden = channels // 16  # bottleneck width (reduction factor 16)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.conv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=True)
+        self.relu1 = nn.ReLU(inplace=True)
+        # a part with zero channels gets no excitation convolution
+        self.conv_a2l = nn.Conv2d(hidden, in_cl, kernel_size=1, bias=True) if in_cl != 0 else None
+        self.conv_a2g = nn.Conv2d(hidden, in_cg, kernel_size=1, bias=True) if in_cg != 0 else None
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        # accept either a (local, global) tuple or a lone tensor, which is treated as local-only
+        id_l, id_g = x if type(x) is tuple else (x, 0)
+        if type(id_g) is int:
+            pooled = id_l
+        else:
+            pooled = torch.cat([id_l, id_g], dim=1)
+        squeezed = self.relu1(self.conv1(self.avgpool(pooled)))
+
+        x_l, x_g = 0, 0
+        if self.conv_a2l is not None:
+            x_l = id_l * self.sigmoid(self.conv_a2l(squeezed))
+        if self.conv_a2g is not None:
+            x_g = id_g * self.sigmoid(self.conv_a2g(squeezed))
+        return x_l, x_g
+
+
+class FourierUnit(nn.Module):
+    """1x1 conv + BatchNorm + ReLU applied to the stacked real/imaginary parts of the input's real FFT, then inverted."""
+    def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
+                 spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'):
+        # real and imaginary parts are stacked as channels, hence the factor 2 on the channel counts below
+        super(FourierUnit, self).__init__()
+        self.groups = groups
+
+        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
+                                          out_channels=out_channels * 2,
+                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
+        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
+        self.relu = torch.nn.ReLU(inplace=True)
+
+        # squeeze and excitation block
+        self.use_se = use_se
+        if use_se:
+            if se_kwargs is None:
+                se_kwargs = {}
+            self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
+
+        self.spatial_scale_factor = spatial_scale_factor
+        self.spatial_scale_mode = spatial_scale_mode
+        self.spectral_pos_encoding = spectral_pos_encoding
+        self.ffc3d = ffc3d
+        self.fft_norm = fft_norm
+
+    def forward(self, x):
+        batch = x.shape[0]
+
+        if self.spatial_scale_factor is not None:
+            orig_size = x.shape[-2:]
+            x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False)
+
+        # real FFT over the last two dims (last three when ffc3d is set)
+        # shape comments below describe a 4-D (batch, c, h, w) input with ffc3d=False
+        fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
+        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
+        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (batch, c, h, w/2+1, 2)
+        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
+        ffted = ffted.view((batch, -1,) + ffted.size()[3:])  # (batch, c*2, h, w/2+1)
+
+        if self.spectral_pos_encoding:
+            height, width = ffted.shape[-2:]
+            coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted)
+            coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted)
+            ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
+
+        if self.use_se:
+            ffted = self.se(ffted)
+
+        ffted = self.conv_layer(ffted)  # (batch, c_out*2, h, w/2+1)
+        ffted = self.relu(self.bn(ffted))
+
+        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
+            0, 1, 3, 4, 2).contiguous()  # (batch, c_out, h, w/2+1, 2)
+        ffted = torch.complex(ffted[..., 0], ffted[..., 1])
+
+        ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
+        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
+
+        if self.spatial_scale_factor is not None:
+            output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
+
+        return output
+
+
+class SeparableFourierUnit(nn.Module):
+    """Fourier unit that transforms along one axis at a time: a row-wise and a column-wise branch, concatenated."""
+    def __init__(self, in_channels, out_channels, groups=1, kernel_size=3):
+        # out_channels is divided between the two branches; the column branch gets the remainder
+        super(SeparableFourierUnit, self).__init__()
+        self.groups = groups
+        row_out_channels = out_channels // 2
+        col_out_channels = out_channels - row_out_channels
+        self.row_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
+                                        out_channels=row_out_channels * 2,
+                                        kernel_size=(kernel_size, 1),  # kernel size is always like this, but the data will be transposed
+                                        stride=1, padding=(kernel_size // 2, 0),
+                                        padding_mode='reflect',
+                                        groups=self.groups, bias=False)
+        self.col_conv = torch.nn.Conv2d(in_channels=in_channels * 2,
+                                        out_channels=col_out_channels * 2,
+                                        kernel_size=(kernel_size, 1),  # kernel size is always like this, but the data will be transposed
+                                        stride=1, padding=(kernel_size // 2, 0),
+                                        padding_mode='reflect',
+                                        groups=self.groups, bias=False)
+        self.row_bn = torch.nn.BatchNorm2d(row_out_channels * 2)
+        self.col_bn = torch.nn.BatchNorm2d(col_out_channels * 2)
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def process_branch(self, x, conv, bn):
+        batch = x.shape[0]
+
+        # 1-D real FFT along the last dimension only
+        # shape comments below describe a 4-D (batch, c, h, w) input
+        ffted = torch.fft.rfft(x, norm="ortho")
+        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (batch, c, h, w/2+1, 2)
+        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
+        ffted = ffted.view((batch, -1,) + ffted.size()[3:])
+
+        ffted = self.relu(bn(conv(ffted)))
+
+        ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
+            0, 1, 3, 4, 2).contiguous()  # (batch, c_out, h, w/2+1, 2)
+        ffted = torch.complex(ffted[..., 0], ffted[..., 1])
+
+        output = torch.fft.irfft(ffted, n=x.shape[-1], norm="ortho")  # irfft takes `n` (an int); it has no `s` argument
+        return output
+
+
+    def forward(self, x):
+        rowwise = self.process_branch(x, self.row_conv, self.row_bn)
+        colwise = self.process_branch(x.permute(0, 1, 3, 2), self.col_conv, self.col_bn).permute(0, 1, 3, 2)
+        out = torch.cat((rowwise, colwise), dim=1)
+        return out
+
+
+class SpectralTransform(nn.Module):
+    """conv1 -> Fourier unit (plus an optional local Fourier unit on a 2x2 tiling of the input) -> conv2."""
+    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, separable_fu=False, **fu_kwargs):
+        # fu_kwargs are forwarded to the main Fourier unit only; the local unit uses its defaults
+        super(SpectralTransform, self).__init__()
+        self.enable_lfu = enable_lfu
+        if stride == 2:
+            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
+        else:
+            self.downsample = nn.Identity()
+
+        self.stride = stride
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels //
+                      2, kernel_size=1, groups=groups, bias=False),
+            nn.BatchNorm2d(out_channels // 2),
+            nn.ReLU(inplace=True)
+        )
+        fu_class = SeparableFourierUnit if separable_fu else FourierUnit
+        self.fu = fu_class(
+            out_channels // 2, out_channels // 2, groups, **fu_kwargs)
+        if self.enable_lfu:
+            self.lfu = fu_class(
+                out_channels // 2, out_channels // 2, groups)
+        self.conv2 = torch.nn.Conv2d(
+            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.conv1(x)
+        output = self.fu(x)
+
+        if self.enable_lfu:
+            n, c, h, w = x.shape
+            split_no = 2
+            # height and width are halved by their own sizes, so non-square inputs keep the expected channel count
+            split_h, split_w = h // split_no, w // split_no
+            xs = torch.cat(torch.split(
+                x[:, :c // 4], split_h, dim=-2), dim=1).contiguous()
+            xs = torch.cat(torch.split(xs, split_w, dim=-1),
+                           dim=1).contiguous()
+            xs = self.lfu(xs)
+            xs = xs.repeat(1, 1, split_no, split_no).contiguous()
+        else:
+            xs = 0
+
+        output = self.conv2(x + output + xs)
+
+        return output
+
+
+class FFC(nn.Module):
+    # Layer with a "local" and a "global" channel group on both input and output.
+    # Four paths connect them: local->local, local->global and global->local are
+    # nn.Conv2d; global->global is a SpectralTransform. A path whose input or output
+    # group has zero channels is replaced by nn.Identity. forward() takes and
+    # returns an (x_local, x_global) pair; a missing group is represented by 0.
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 ratio_gin, ratio_gout, stride=1, padding=0,
+                 dilation=1, groups=1, bias=False, enable_lfu=True,
+                 padding_type='reflect', gated=False, **spectral_kwargs):
+        super(FFC, self).__init__()
+
+        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
+        self.stride = stride
+
+        # Channel split: the global group gets int(channels * ratio), local gets the rest.
+        in_cg = int(in_channels * ratio_gin)
+        in_cl = in_channels - in_cg
+        out_cg = int(out_channels * ratio_gout)
+        out_cl = out_channels - out_cg
+        #groups_g = 1 if groups == 1 else int(groups * ratio_gout)
+        #groups_l = 1 if groups == 1 else groups - groups_g
+
+        self.ratio_gin = ratio_gin
+        self.ratio_gout = ratio_gout
+        self.global_in_num = in_cg
+
+        # In each case below nn.Identity is constructed with the same arguments as
+        # the real module would be; it ignores them.
+        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
+        self.convl2l = module(in_cl, out_cl, kernel_size,
+                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
+        self.convl2g = module(in_cl, out_cg, kernel_size,
+                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
+        self.convg2l = module(in_cg, out_cl, kernel_size,
+                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
+        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
+        self.convg2g = module(
+            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs)
+
+        self.gated = gated
+        # Gate: 1x1 conv over all input channels producing two maps (one per cross path).
+        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
+        self.gate = module(in_channels, 2, 1)
+
+    def forward(self, x):
+        # A bare (non-tuple) input is treated as local-only, with 0 as the global part.
+        x_l, x_g = x if type(x) is tuple else (x, 0)
+        out_xl, out_xg = 0, 0
+
+        if self.gated:
+            # Concatenate whichever parts are tensors and derive two sigmoid gates.
+            total_input_parts = [x_l]
+            if torch.is_tensor(x_g):
+                total_input_parts.append(x_g)
+            total_input = torch.cat(total_input_parts, dim=1)
+
+            gates = torch.sigmoid(self.gate(total_input))
+            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
+        else:
+            g2l_gate, l2g_gate = 1, 1
+
+        # The local output is skipped when everything is global (ratio_gout == 1),
+        # the global output when nothing is global (ratio_gout == 0); the skipped
+        # output stays 0.
+        if self.ratio_gout != 1:
+            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
+        if self.ratio_gout != 0:
+            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
+
+        return out_xl, out_xg
+
+
+class FFC_BN_ACT(nn.Module):
+    # FFC followed by a separate norm layer and activation for the local and the
+    # global output. A branch that carries no channels (local when ratio_gout == 1,
+    # global when ratio_gout == 0) gets nn.Identity for both norm and activation.
+
+    def __init__(self, in_channels, out_channels,
+                 kernel_size, ratio_gin, ratio_gout,
+                 stride=1, padding=0, dilation=1, groups=1, bias=False,
+                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
+                 padding_type='reflect',
+                 enable_lfu=True, **kwargs):
+        super().__init__()
+        self.ffc = FFC(in_channels, out_channels, kernel_size,
+                       ratio_gin, ratio_gout, stride, padding, dilation,
+                       groups, bias, enable_lfu, padding_type=padding_type, **kwargs)
+
+        local_is_empty = ratio_gout == 1
+        global_is_empty = ratio_gout == 0
+        n_global = int(out_channels * ratio_gout)
+        n_local = out_channels - n_global
+
+        # Norm layers: local first, then global (same registration order as before).
+        local_norm_ctor = nn.Identity if local_is_empty else norm_layer
+        global_norm_ctor = nn.Identity if global_is_empty else norm_layer
+        self.bn_l = local_norm_ctor(n_local)
+        self.bn_g = global_norm_ctor(n_global)
+
+        # Activations are always constructed with inplace=True.
+        local_act_ctor = nn.Identity if local_is_empty else activation_layer
+        global_act_ctor = nn.Identity if global_is_empty else activation_layer
+        self.act_l = local_act_ctor(inplace=True)
+        self.act_g = global_act_ctor(inplace=True)
+
+    def forward(self, x):
+        # Returns the (local, global) pair after norm and activation.
+        local_out, global_out = self.ffc(x)
+        local_out = self.act_l(self.bn_l(local_out))
+        global_out = self.act_g(self.bn_g(global_out))
+        return local_out, global_out
+
+
+class FFCResnetBlock(nn.Module):
+    # Residual block made of two FFC_BN_ACT layers (3x3 kernels, padding == dilation).
+    #
+    # dim: channel count of both input and output.
+    # spatial_transform_kwargs: when not None, both layers are wrapped in
+    #     LearnableSpatialTransformWrapper with these kwargs.
+    # inline: when True, forward() takes and returns a single tensor whose channels
+    #     are [local | global]; otherwise it takes a tensor or an (x_l, x_g) tuple
+    #     and returns an (x_l, x_g) tuple.
+    # conv_kwargs: forwarded to both FFC_BN_ACT layers (e.g. ratio_gin / ratio_gout).
+    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
+                 spatial_transform_kwargs=None, inline=False, **conv_kwargs):
+        super().__init__()
+        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
+                                norm_layer=norm_layer,
+                                activation_layer=activation_layer,
+                                padding_type=padding_type,
+                                **conv_kwargs)
+        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
+                                norm_layer=norm_layer,
+                                activation_layer=activation_layer,
+                                padding_type=padding_type,
+                                **conv_kwargs)
+        if spatial_transform_kwargs is not None:
+            self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs)
+            self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs)
+        self.inline = inline
+
+    def forward(self, x):
+        if self.inline:
+            # NOTE(review): this reads `.ffc` directly from conv1, which presumably
+            # fails when conv1 has been wrapped by LearnableSpatialTransformWrapper -
+            # confirm whether inline + spatial_transform_kwargs is ever used.
+            n_global = self.conv1.ffc.global_in_num
+            # Split by an explicit index: the negative-index form `x[:, :-n_global]`
+            # yields an empty tensor when n_global == 0. A branch that has no
+            # channels is represented by 0, as in the non-inline path.
+            n_local = max(x.shape[1] - n_global, 0)
+            x_l = x[:, :n_local] if n_local > 0 else 0
+            x_g = x[:, n_local:] if n_global > 0 else 0
+        else:
+            x_l, x_g = x if type(x) is tuple else (x, 0)
+
+        id_l, id_g = x_l, x_g
+
+        x_l, x_g = self.conv1((x_l, x_g))
+        x_l, x_g = self.conv2((x_l, x_g))
+
+        # Residual connection on each branch separately.
+        x_l, x_g = id_l + x_l, id_g + x_g
+        out = x_l, x_g
+        if self.inline:
+            # Only tensor parts can be concatenated; an absent branch is the int 0.
+            out = torch.cat([part for part in out if torch.is_tensor(part)], dim=1)
+        return out
+
+
+class ConcatTupleLayer(nn.Module):
+    # Collapses an (x_l, x_g) tuple into one tensor: concatenates along dim 1 when
+    # the second element is a tensor, otherwise returns the first element unchanged.
+    def forward(self, x):
+        assert isinstance(x, tuple)
+        local_part, global_part = x
+        assert torch.is_tensor(local_part) or torch.is_tensor(global_part)
+        if torch.is_tensor(global_part):
+            return torch.cat(x, dim=1)
+        return local_part
+
+
+class FFCResNetGenerator(nn.Module):
+    # Generator assembled as one nn.Sequential: reflection pad + 7x7 FFC_BN_ACT stem,
+    # n_downsampling stride-2 FFC_BN_ACT layers, n_blocks FFCResnetBlocks, a
+    # ConcatTupleLayer, n_downsampling ConvTranspose2d upsampling stages, an optional
+    # inline FFCResnetBlock, and a reflection pad + 7x7 conv to output_nc channels.
+    #
+    # max_features caps the channel count at every stage.
+    # spatial_transform_layers: indices of resnet blocks to wrap in
+    #     LearnableSpatialTransformWrapper (None disables wrapping).
+    # add_out_act: True appends get_activation('tanh'); any other truthy value is
+    #     passed to get_activation as-is; falsy appends nothing.
+    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect', activation_layer=nn.ReLU,
+                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True),
+                 init_conv_kwargs={}, downsample_conv_kwargs={}, resnet_conv_kwargs={},
+                 spatial_transform_layers=None, spatial_transform_kwargs={},
+                 add_out_act=True, max_features=1024, out_ffc=False, out_ffc_kwargs={}):
+        assert (n_blocks >= 0)
+        super().__init__()
+
+        model = [nn.ReflectionPad2d(3),
+                 FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
+                            activation_layer=activation_layer, **init_conv_kwargs)]
+
+        ### downsample
+        for i in range(n_downsampling):
+            mult = 2 ** i
+            if i == n_downsampling - 1:
+                # The last downsampling layer must emit the local/global split the
+                # resnet blocks expect, so its ratio_gout is overridden (on a copy of
+                # the kwargs) with the resnet blocks' ratio_gin.
+                cur_conv_kwargs = dict(downsample_conv_kwargs)
+                cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
+            else:
+                cur_conv_kwargs = downsample_conv_kwargs
+            model += [FFC_BN_ACT(min(max_features, ngf * mult),
+                                 min(max_features, ngf * mult * 2),
+                                 kernel_size=3, stride=2, padding=1,
+                                 norm_layer=norm_layer,
+                                 activation_layer=activation_layer,
+                                 **cur_conv_kwargs)]
+
+        mult = 2 ** n_downsampling
+        feats_num_bottleneck = min(max_features, ngf * mult)
+
+        ### resnet blocks
+        for i in range(n_blocks):
+            cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
+                                          norm_layer=norm_layer, **resnet_conv_kwargs)
+            if spatial_transform_layers is not None and i in spatial_transform_layers:
+                cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, **spatial_transform_kwargs)
+            model += [cur_resblock]
+
+        # From here on the layers operate on a single tensor instead of a tuple.
+        model += [ConcatTupleLayer()]
+
+        ### upsample
+        for i in range(n_downsampling):
+            mult = 2 ** (n_downsampling - i)
+            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
+                                         min(max_features, int(ngf * mult / 2)),
+                                         kernel_size=3, stride=2, padding=1, output_padding=1),
+                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
+                      up_activation]
+
+        if out_ffc:
+            model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
+                                     norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]
+
+        model += [nn.ReflectionPad2d(3),
+                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+        if add_out_act:
+            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input):
+        # Runs the whole sequential model on `input`.
+        return self.model(input)
+
+
+class FFCNLayerDiscriminator(BaseDiscriminator):
+    # Discriminator built from FFC_BN_ACT stages registered as attributes
+    # model0 ... model{n_layers + 1}: one stem, n_layers - 1 stride-2 stages, one
+    # stride-1 stage followed by ConcatTupleLayer, and a final 1-channel conv.
+    # forward() returns the final prediction and the list of intermediate activations.
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, max_features=512,
+                 init_conv_kwargs={}, conv_kwargs={}):
+        super().__init__()
+        self.n_layers = n_layers
+
+        # FFC_BN_ACT calls activation_layer(inplace=True), so a constructor is
+        # needed rather than a module instance.
+        def _act_ctor(inplace=True):
+            return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
+
+        kw = 3
+        padw = int(np.ceil((kw-1.0)/2))
+        sequence = [[FFC_BN_ACT(input_nc, ndf, kernel_size=kw, padding=padw, norm_layer=norm_layer,
+                                activation_layer=_act_ctor, **init_conv_kwargs)]]
+
+        nf = ndf
+        for n in range(1, n_layers):
+            nf_prev = nf
+            nf = min(nf * 2, max_features)
+
+            cur_model = [
+                FFC_BN_ACT(nf_prev, nf,
+                           kernel_size=kw, stride=2, padding=padw,
+                           norm_layer=norm_layer,
+                           activation_layer=_act_ctor,
+                           **conv_kwargs)
+            ]
+            sequence.append(cur_model)
+
+        nf_prev = nf
+        # NOTE(review): this cap is the literal 512 rather than max_features, unlike
+        # the loop above; the two only agree for the default max_features=512.
+        nf = min(nf * 2, 512)
+
+        cur_model = [
+            FFC_BN_ACT(nf_prev, nf,
+                       kernel_size=kw, stride=1, padding=padw,
+                       norm_layer=norm_layer,
+                       activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs),
+                       **conv_kwargs),
+            ConcatTupleLayer()
+        ]
+        sequence.append(cur_model)
+
+        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+        # Register each stage as its own nn.Sequential named model<index>.
+        for n in range(len(sequence)):
+            setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
+
+    def get_all_activations(self, x):
+        # Feeds x through model0 ... model{n_layers + 1} in order and returns the
+        # output of every stage (the input itself is not included).
+        res = [x]
+        for n in range(self.n_layers + 2):
+            model = getattr(self, 'model' + str(n))
+            res.append(model(res[-1]))
+        return res[1:]
+
+    def forward(self, x):
+        # Returns (last stage output, list of all earlier stage outputs). Tuple
+        # outputs are flattened: concatenated along dim 1 when the second element is
+        # a tensor, otherwise reduced to the first element.
+        act = self.get_all_activations(x)
+        feats = []
+        for out in act[:-1]:
+            if isinstance(out, tuple):
+                if torch.is_tensor(out[1]):
+                    out = torch.cat(out, dim=1)
+                else:
+                    out = out[0]
+            feats.append(out)
+        return act[-1], feats
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/multidilated_conv.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/multidilated_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57d0b457d4b30aeeffcd8cba138a502ba7affc5
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/multidilated_conv.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+import random
+from annotator.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv
+
+class MultidilatedConv(nn.Module):
+    # Bank of dilation_num parallel convolutions whose dilation doubles from one
+    # conv to the next, starting at min_dilation.
+    #
+    # comb_mode controls how inputs and outputs are combined:
+    #   'sum'      - every conv sees the full input; outputs are summed.
+    #   'cat_out'  - every conv sees the full input; outputs are concatenated and
+    #                re-ordered with self.index.
+    #   'cat_in'   - input channels are split between the convs; outputs are summed.
+    #   'cat_both' - input channels are split and outputs are concatenated.
+    # equal_dim: split channels evenly (requires divisibility by dilation_num);
+    #     otherwise conv i gets dim // 2 ** (i + 1) channels and the last conv gets
+    #     the remainder.
+    # shared_weights: convs after the first reuse the first conv's weight and bias.
+    # padding: an int is multiplied by the current dilation; otherwise it is indexed
+    #     per conv.
+    # shuffle_in_channels: apply a fixed random permutation to the input channels.
+    # kwargs: forwarded to each conv constructor.
+    def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True,
+                 shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs):
+        super().__init__()
+        convs = []
+        self.equal_dim = equal_dim
+        assert comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode
+        if comb_mode in ('cat_out', 'cat_both'):
+            self.cat_out = True
+            if equal_dim:
+                assert out_dim % dilation_num == 0
+                out_dims = [out_dim // dilation_num] * dilation_num
+                # Interleaving permutation: for each channel position i, take
+                # channel i of conv 0, then of conv 1, and so on.
+                self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [])
+            else:
+                out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
+                out_dims.append(out_dim - sum(out_dims))
+                index = []
+                starts = [0] + out_dims[:-1]
+                # Each conv contributes lengths[j] consecutive channels per round;
+                # there are out_dims[-1] rounds.
+                # NOTE(review): `starts` holds the per-conv sizes rather than their
+                # cumulative offsets into the concatenated output - confirm intent.
+                lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)]
+                for i in range(out_dims[-1]):
+                    for j in range(dilation_num):
+                        index += list(range(starts[j], starts[j] + lengths[j]))
+                        starts[j] += lengths[j]
+                self.index = index
+                assert(len(index) == out_dim)
+            self.out_dims = out_dims
+        else:
+            self.cat_out = False
+            self.out_dims = [out_dim] * dilation_num
+
+        if comb_mode in ('cat_in', 'cat_both'):
+            if equal_dim:
+                assert in_dim % dilation_num == 0
+                in_dims = [in_dim // dilation_num] * dilation_num
+            else:
+                in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)]
+                in_dims.append(in_dim - sum(in_dims))
+            self.in_dims = in_dims
+            self.cat_in = True
+        else:
+            self.cat_in = False
+            self.in_dims = [in_dim] * dilation_num
+
+        conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d
+        dilation = min_dilation
+        for i in range(dilation_num):
+            if isinstance(padding, int):
+                cur_padding = padding * dilation
+            else:
+                cur_padding = padding[i]
+            convs.append(conv_type(
+                self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs
+            ))
+            if i > 0 and shared_weights:
+                # Tie this conv's parameters to those of the first conv.
+                convs[-1].weight = convs[0].weight
+                convs[-1].bias = convs[0].bias
+            dilation *= 2
+        self.convs = nn.ModuleList(convs)
+
+        self.shuffle_in_channels = shuffle_in_channels
+        if self.shuffle_in_channels:
+            # shuffle list as shuffling of tensors is nondeterministic
+            in_channels_permute = list(range(in_dim))
+            random.shuffle(in_channels_permute)
+            # save as buffer so it is saved and loaded with checkpoint
+            self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute))
+
+    def forward(self, x):
+        # Applies the optional channel permutation, routes the (possibly split)
+        # input through every conv, then concatenates + re-indexes or sums.
+        if self.shuffle_in_channels:
+            x = x[:, self.in_channels_permute]
+
+        outs = []
+        if self.cat_in:
+            if self.equal_dim:
+                x = x.chunk(len(self.convs), dim=1)
+            else:
+                # Uneven split: consecutive channel slices sized by self.in_dims.
+                new_x = []
+                start = 0
+                for dim in self.in_dims:
+                    new_x.append(x[:, start:start+dim])
+                    start += dim
+                x = new_x
+        for i, conv in enumerate(self.convs):
+            if self.cat_in:
+                input = x[i]
+            else:
+                input = x
+            outs.append(conv(input))
+        if self.cat_out:
+            out = torch.cat(outs, dim=1)[:, self.index]
+        else:
+            out = sum(outs)
+        return out
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/multiscale.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/multiscale.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f41252f3c7509ee58b939215baef328cfbe48c8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/multiscale.py
@@ -0,0 +1,244 @@
+from typing import List, Tuple, Union, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.lama.saicinpainting.training.modules.base import get_conv_block_ctor, get_activation
+from annotator.lama.saicinpainting.training.modules.pix2pixhd import ResnetBlock
+
+
+class ResNetHead(nn.Module):
+    # Encoder half of a ResNet generator: reflection pad + 7x7 conv stem,
+    # n_downsampling stride-2 convs that double the channel count each time, and
+    # n_blocks ResnetBlocks at the resulting width. All layers live in one
+    # nn.Sequential stored as self.model; the same `activation` instance is reused
+    # at every position.
+    def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True)):
+        assert (n_blocks >= 0)
+        super().__init__()
+
+        conv_layer = get_conv_block_ctor(conv_kind)
+
+        layers = []
+
+        # stem
+        layers.append(nn.ReflectionPad2d(3))
+        layers.append(conv_layer(input_nc, ngf, kernel_size=7, padding=0))
+        layers.append(norm_layer(ngf))
+        layers.append(activation)
+
+        # downsampling stages
+        for level in range(n_downsampling):
+            width_in = ngf * 2 ** level
+            width_out = width_in * 2
+            layers.append(conv_layer(width_in, width_out, kernel_size=3, stride=2, padding=1))
+            layers.append(norm_layer(width_out))
+            layers.append(activation)
+
+        # residual blocks at the bottleneck width
+        bottleneck_width = ngf * 2 ** n_downsampling
+        for _ in range(n_blocks):
+            layers.append(ResnetBlock(bottleneck_width, padding_type=padding_type, activation=activation,
+                                      norm_layer=norm_layer, conv_kind=conv_kind))
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, input):
+        # Runs the sequential encoder on `input`.
+        return self.model(input)
+
+
+class ResNetTail(nn.Module):
+    # Decoder half of a ResNet generator. self.model holds an optional 1x1 input
+    # projection (when add_in_proj is given), n_blocks ResnetBlocks and
+    # n_downsampling ConvTranspose2d upsampling stages; self.out_proj holds
+    # out_extra_layers_n 1x1 conv stages, a reflection pad + 7x7 conv to output_nc,
+    # and an optional output activation.
+    def __init__(self, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
+                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
+                 add_in_proj=None):
+        assert (n_blocks >= 0)
+        super().__init__()
+
+        bottleneck_width = ngf * 2 ** n_downsampling
+
+        body = []
+
+        # optional projection of the incoming features to the bottleneck width
+        if add_in_proj is not None:
+            body.append(nn.Conv2d(add_in_proj, bottleneck_width, kernel_size=1))
+
+        # residual blocks
+        for _ in range(n_blocks):
+            body.append(ResnetBlock(bottleneck_width, padding_type=padding_type, activation=activation,
+                                    norm_layer=norm_layer, conv_kind=conv_kind))
+
+        # upsampling stages, halving the width each time
+        for level in range(n_downsampling):
+            mult = 2 ** (n_downsampling - level)
+            body.append(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
+                                           output_padding=1))
+            body.append(up_norm_layer(int(ngf * mult / 2)))
+            body.append(up_activation)
+        self.model = nn.Sequential(*body)
+
+        head = []
+        for _ in range(out_extra_layers_n):
+            head.append(nn.Conv2d(ngf, ngf, kernel_size=1, padding=0))
+            head.append(up_norm_layer(ngf))
+            head.append(up_activation)
+        head.append(nn.ReflectionPad2d(3))
+        head.append(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))
+
+        # True selects 'tanh'; any other truthy value is handed to get_activation as-is.
+        if add_out_act:
+            head.append(get_activation('tanh' if add_out_act is True else add_out_act))
+
+        self.out_proj = nn.Sequential(*head)
+
+    def forward(self, input, return_last_act=False):
+        # Returns the projected output, plus the pre-projection features when
+        # return_last_act is truthy.
+        features = self.model(input)
+        out = self.out_proj(features)
+        if not return_last_act:
+            return out
+        return out, features
+
+
+class MultiscaleResNet(nn.Module):
+    # Multi-scale generator with one ResNetHead and one ResNetTail per scale.
+    # Scales are processed from the last (lowest-resolution) input to the first;
+    # every tail except the last one in the list has a 1x1 input projection because
+    # it receives its head's features concatenated with the previous tail's features.
+    #
+    # out_cumulative: add each scale's output to the upsampled running sum of the
+    #     coarser outputs.
+    # return_only_hr: return only the highest-resolution output instead of a list.
+    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=2, n_blocks_head=2, n_blocks_tail=6, n_scales=3,
+                 norm_layer=nn.BatchNorm2d, padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
+                 up_norm_layer=nn.BatchNorm2d, up_activation=nn.ReLU(True), add_out_act=False, out_extra_layers_n=0,
+                 out_cumulative=False, return_only_hr=False):
+        super().__init__()
+
+        self.heads = nn.ModuleList([ResNetHead(input_nc, ngf=ngf, n_downsampling=n_downsampling,
+                                               n_blocks=n_blocks_head, norm_layer=norm_layer, padding_type=padding_type,
+                                               conv_kind=conv_kind, activation=activation)
+                                    for i in range(n_scales)])
+        # Head features (ngf * 2 ** n_downsampling channels) plus the previous
+        # tail's features (ngf channels).
+        tail_in_feats = ngf * (2 ** n_downsampling) + ngf
+        self.tails = nn.ModuleList([ResNetTail(output_nc,
+                                               ngf=ngf, n_downsampling=n_downsampling,
+                                               n_blocks=n_blocks_tail, norm_layer=norm_layer, padding_type=padding_type,
+                                               conv_kind=conv_kind, activation=activation, up_norm_layer=up_norm_layer,
+                                               up_activation=up_activation, add_out_act=add_out_act,
+                                               out_extra_layers_n=out_extra_layers_n,
+                                               add_in_proj=None if (i == n_scales - 1) else tail_in_feats)
+                                    for i in range(n_scales)])
+
+        self.out_cumulative = out_cumulative
+        self.return_only_hr = return_only_hr
+
+    @property
+    def num_scales(self):
+        # Number of scales, i.e. the number of heads.
+        return len(self.heads)
+
+    def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
+            -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        :param ms_inputs: List of inputs of different resolutions from HR to LR
+        :param smallest_scales_num: int or None, number of smallest scales to take at input
+        :return: Depending on return_only_hr:
+            True: Only the most HR output
+            False: List of outputs of different resolutions from HR to LR
+        """
+        if smallest_scales_num is None:
+            assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num)
+            smallest_scales_num = len(self.heads)
+        else:
+            assert smallest_scales_num == len(ms_inputs) <= len(self.heads), (len(self.heads), len(ms_inputs), smallest_scales_num)
+
+        # Use only the last `smallest_scales_num` heads, paired with the inputs in order.
+        cur_heads = self.heads[-smallest_scales_num:]
+        ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)]
+
+        all_outputs = []
+        prev_tail_features = None
+        # Walk from the last scale to the first, carrying the tail features forward.
+        for i in range(len(ms_features)):
+            scale_i = -i - 1
+
+            cur_tail_input = ms_features[-i - 1]
+            if prev_tail_features is not None:
+                # Resize the previous tail's features to the current spatial size
+                # before concatenating along channels.
+                if prev_tail_features.shape != cur_tail_input.shape:
+                    prev_tail_features = F.interpolate(prev_tail_features, size=cur_tail_input.shape[2:],
+                                                       mode='bilinear', align_corners=False)
+                cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1)
+
+            cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True)
+
+            prev_tail_features = cur_tail_feats
+            all_outputs.append(cur_out)
+
+        if self.out_cumulative:
+            # Running sum: each output is added to the previous cumulative output
+            # resized to its resolution.
+            all_outputs_cum = [all_outputs[0]]
+            for i in range(1, len(ms_features)):
+                cur_out = all_outputs[i]
+                cur_out_cum = cur_out + F.interpolate(all_outputs_cum[-1], size=cur_out.shape[2:],
+                                                      mode='bilinear', align_corners=False)
+                all_outputs_cum.append(cur_out_cum)
+            all_outputs = all_outputs_cum
+
+        # all_outputs was filled last-scale-first, so the final entry belongs to the
+        # first input and reversing restores the input order.
+        if self.return_only_hr:
+            return all_outputs[-1]
+        else:
+            return all_outputs[::-1]
+
+
+class MultiscaleDiscriminatorSimple(nn.Module):
+    # Holds one discriminator per scale (in an nn.ModuleList) and applies each to
+    # the input of the matching scale.
+    def __init__(self, ms_impl):
+        # ms_impl: iterable of per-scale discriminator modules.
+        super().__init__()
+        self.ms_impl = nn.ModuleList(ms_impl)
+
+    @property
+    def num_scales(self):
+        # Number of per-scale discriminators.
+        return len(self.ms_impl)
+
+    def forward(self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None) \
+            -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        :param ms_inputs: List of inputs of different resolutions from HR to LR
+        :param smallest_scales_num: int or None, number of smallest scales to take at input
+        :return: List of pairs (prediction, features) for different resolutions from HR to LR
+        """
+        if smallest_scales_num is None:
+            assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
+            # This class stores its sub-modules in `ms_impl`; it has no `heads`
+            # attribute, so the count must come from `ms_impl`.
+            smallest_scales_num = len(self.ms_impl)
+        else:
+            assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), \
+                (len(self.ms_impl), len(ms_inputs), smallest_scales_num)
+
+        # Pair the last `smallest_scales_num` discriminators with the inputs in order.
+        return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)]
+
+
+class SingleToMultiScaleInputMixin:
+    # Mixin turning a single input into a resolution pyramid: level i is the input
+    # bilinearly resized to (height // 2 ** i, width // 2 ** i) for i in
+    # range(self.num_scales). The pyramid is passed to the next forward() in the MRO.
+    def forward(self, x: torch.Tensor) -> List:
+        full_height, full_width = x.shape[2:]
+        pyramid = []
+        for level in range(self.num_scales):
+            shrink = 2 ** level
+            target_size = (full_height // shrink, full_width // shrink)
+            pyramid.append(F.interpolate(x, size=target_size, mode='bilinear', align_corners=False))
+        return super().forward(pyramid)
+
+
+class GeneratorMultiToSingleOutputMixin:
+    # Mixin that keeps only the first element of the next forward()'s result.
+    def forward(self, x):
+        outputs = super().forward(x)
+        return outputs[0]
+
+
+class DiscriminatorMultiToSingleOutputMixin:
+    # Mixin reducing a list of (prediction, features) pairs to the first pair's
+    # prediction plus one flat list with the features of every pair, in order.
+    def forward(self, x):
+        per_scale = super().forward(x)
+        first_prediction = per_scale[0][0]
+        flat_features = []
+        for _, scale_features in per_scale:
+            flat_features.extend(scale_features)
+        return first_prediction, flat_features
+
+
+class DiscriminatorMultiToSingleOutputStackedMixin:
+    # Mixin that merges per-scale (prediction, features) pairs: every prediction
+    # after the first is bilinearly resized to the first one's spatial size and all
+    # are concatenated along dim 1. Features are flattened into one list, optionally
+    # restricted to the scale indices in return_feats_only_levels.
+    def __init__(self, *args, return_feats_only_levels=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.return_feats_only_levels = return_feats_only_levels
+
+    def forward(self, x):
+        per_scale = super().forward(x)
+
+        predictions = [prediction for prediction, _ in per_scale]
+        reference = predictions[0]
+        aligned = [reference]
+        for other in predictions[1:]:
+            aligned.append(F.interpolate(other, size=reference.shape[-2:],
+                                         mode='bilinear', align_corners=False))
+        stacked = torch.cat(aligned, dim=1)
+
+        if self.return_feats_only_levels is None:
+            selected = [scale_features for _, scale_features in per_scale]
+        else:
+            selected = [per_scale[level][1] for level in self.return_feats_only_levels]
+
+        flat_features = []
+        for scale_features in selected:
+            flat_features.extend(scale_features)
+        return stacked, flat_features
+
+
+# Composition by MRO: SingleToMultiScaleInputMixin.forward runs first, then
+# DiscriminatorMultiToSingleOutputStackedMixin.forward, then
+# MultiscaleDiscriminatorSimple.forward. No behaviour is added here.
+class MultiscaleDiscrSingleInput(SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple):
+    pass
+
+
+# Composition by MRO: GeneratorMultiToSingleOutputMixin.forward runs first, then
+# SingleToMultiScaleInputMixin.forward, then MultiscaleResNet.forward. No behaviour
+# is added here.
+class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet):
+    pass
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/pix2pixhd.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/pix2pixhd.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4fcfcff083f9ce4d3c7880ff0f74f8f745a251
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/pix2pixhd.py
@@ -0,0 +1,669 @@
+# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py
+import collections
+from functools import partial
+import functools
+import logging
+from collections import defaultdict
+
+import numpy as np
+import torch.nn as nn
+
+from annotator.lama.saicinpainting.training.modules.base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation
+from annotator.lama.saicinpainting.training.modules.ffc import FFCResnetBlock
+from annotator.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv
+
+class DotDict(defaultdict):
+    # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
+    """dot.notation access to dictionary attributes"""
+    # Attribute reads that fall through to __getattr__ go to dict.get, so a missing
+    # key yields None and does not invoke the default factory or insert the key.
+    __getattr__ = defaultdict.get
+    # Attribute assignment and deletion act on the dictionary items.
+    __setattr__ = defaultdict.__setitem__
+    __delattr__ = defaultdict.__delitem__
+
+class Identity(nn.Module):
+    # Pass-through module: forward() returns its argument unchanged.
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
+class ResnetBlock(nn.Module):
+    # Residual block with two 3x3 convs: pad -> conv -> norm -> activation
+    # [-> dropout] -> pad -> conv -> norm, added to a skip connection.
+    #
+    # in_dim: when given, the first conv maps in_dim -> dim and the skip connection
+    #     goes through a 1x1 conv (in_dim -> dim); otherwise both use dim.
+    # dilation / second_dilation: dilation (and padding amount) of the first and
+    #     second conv; second_dilation defaults to dilation.
+    # groups: passed to the second conv only.
+    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
+                 dilation=1, in_dim=None, groups=1, second_dilation=None):
+        super().__init__()
+        self.in_dim = in_dim
+        self.dim = dim
+        if second_dilation is None:
+            second_dilation = dilation
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
+                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
+                                                second_dilation=second_dilation)
+
+        if self.in_dim is not None:
+            self.input_conv = nn.Conv2d(in_dim, dim, 1)
+
+        self.out_channnels = dim
+
+    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
+                         dilation=1, in_dim=None, groups=1, second_dilation=1):
+        conv_layer = get_conv_block_ctor(conv_kind)
+
+        def make_padding(amount):
+            # Returns (explicit padding layers, padding argument for the conv).
+            if padding_type == 'reflect':
+                return [nn.ReflectionPad2d(amount)], 0
+            if padding_type == 'replicate':
+                return [nn.ReplicationPad2d(amount)], 0
+            if padding_type == 'zero':
+                return [], amount
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+        # first conv stage
+        layers, conv_padding = make_padding(dilation)
+        first_in_dim = dim if in_dim is None else in_dim
+        layers.append(conv_layer(first_in_dim, dim, kernel_size=3, padding=conv_padding, dilation=dilation))
+        layers.append(norm_layer(dim))
+        layers.append(activation)
+        if use_dropout:
+            layers.append(nn.Dropout(0.5))
+
+        # second conv stage
+        pad_layers, conv_padding = make_padding(second_dilation)
+        layers.extend(pad_layers)
+        layers.append(conv_layer(dim, dim, kernel_size=3, padding=conv_padding, dilation=second_dilation,
+                                 groups=groups))
+        layers.append(norm_layer(dim))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        # The conv block always sees the raw input; only the skip path is projected.
+        skip = x
+        if self.in_dim is not None:
+            skip = self.input_conv(x)
+        return skip + self.conv_block(x)
+
+class ResnetBlock5x5(nn.Module):
+    # Residual block with two 5x5 convs; padding is twice the dilation for each
+    # conv. When in_dim is given, the first conv maps in_dim -> dim and the skip
+    # connection goes through a 1x1 conv. `groups` is passed to the second conv only.
+    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False, conv_kind='default',
+                 dilation=1, in_dim=None, groups=1, second_dilation=None):
+        super(ResnetBlock5x5, self).__init__()
+        self.in_dim = in_dim
+        self.dim = dim
+        # The second conv reuses the first conv's dilation unless told otherwise.
+        if second_dilation is None:
+            second_dilation = dilation
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout,
+                                                conv_kind=conv_kind, dilation=dilation, in_dim=in_dim, groups=groups,
+                                                second_dilation=second_dilation)
+
+        if self.in_dim is not None:
+            self.input_conv = nn.Conv2d(in_dim, dim, 1)
+
+        self.out_channnels = dim
+
+    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout, conv_kind='default',
+                         dilation=1, in_dim=None, groups=1, second_dilation=1):
+        # Builds pad -> conv5x5 -> norm -> activation [-> dropout] -> pad -> conv5x5
+        # -> norm. 'reflect' and 'replicate' use an explicit padding layer; 'zero'
+        # uses the conv's own padding argument. Other values raise NotImplementedError.
+        conv_layer = get_conv_block_ctor(conv_kind)
+
+        conv_block = []
+        p = 0
+        if padding_type == 'reflect':
+            conv_block += [nn.ReflectionPad2d(dilation * 2)]
+        elif padding_type == 'replicate':
+            conv_block += [nn.ReplicationPad2d(dilation * 2)]
+        elif padding_type == 'zero':
+            p = dilation * 2
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+        if in_dim is None:
+            in_dim = dim
+
+        conv_block += [conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation),
+                       norm_layer(dim),
+                       activation]
+        if use_dropout:
+            conv_block += [nn.Dropout(0.5)]
+
+        p = 0
+        if padding_type == 'reflect':
+            conv_block += [nn.ReflectionPad2d(second_dilation * 2)]
+        elif padding_type == 'replicate':
+            conv_block += [nn.ReplicationPad2d(second_dilation * 2)]
+        elif padding_type == 'zero':
+            p = second_dilation * 2
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+        conv_block += [conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups),
+                       norm_layer(dim)]
+
+        return nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        # The conv block sees the raw input; only the skip path goes through input_conv.
+        x_before = x
+        if self.in_dim is not None:
+            x = self.input_conv(x)
+        out = x + self.conv_block(x_before)
+        return out
+
+
+class MultidilatedResnetBlock(nn.Module):
+    # Residual block whose two 3x3 convs are built by the supplied conv_layer
+    # constructor: conv -> norm -> activation [-> dropout] -> conv -> norm, added to
+    # the input. padding_type is forwarded to conv_layer as padding_mode.
+    def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False):
+        super(MultidilatedResnetBlock, self).__init__()
+        self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout)
+
+    def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1):
+        # `dilation` is accepted but not used.
+        layers = [
+            conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type),
+            norm_layer(dim),
+            activation,
+        ]
+        if use_dropout:
+            layers.append(nn.Dropout(0.5))
+        layers.append(conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type))
+        layers.append(norm_layer(dim))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        return x + self.conv_block(x)
+
+
+class MultiDilatedGlobalGenerator(nn.Module):
+    # Generator assembled as one nn.Sequential: reflection pad + 7x7 conv stem,
+    # n_downsampling stride-2 convs, n_blocks MultidilatedResnetBlocks (each
+    # optionally preceded by an inline FFCResnetBlock), n_downsampling upsampling
+    # stages from deconv_factory, and a reflection pad + 7x7 conv to output_nc.
+    #
+    # max_features caps the channel count of the downsampling and bottleneck stages.
+    # affine: when not None, bound as the `affine` argument of both norm constructors.
+    # ffc_positions: indices of resnet blocks preceded by an FFCResnetBlock.
+    # add_out_act: True appends get_activation('tanh'); any other truthy value is
+    #     passed to get_activation as-is; falsy appends nothing.
+    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
+                 n_blocks=3, norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect', conv_kind='default',
+                 deconv_kind='convtranspose', activation=nn.ReLU(True),
+                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
+                 add_out_act=True, max_features=1024, multidilation_kwargs={},
+                 ffc_positions=None, ffc_kwargs={}):
+        assert (n_blocks >= 0)
+        super().__init__()
+
+        conv_layer = get_conv_block_ctor(conv_kind)
+        # Resnet blocks always use the 'multidilated' conv constructor here.
+        resnet_conv_layer = functools.partial(get_conv_block_ctor('multidilated'), **multidilation_kwargs)
+        norm_layer = get_norm_layer(norm_layer)
+        if affine is not None:
+            norm_layer = partial(norm_layer, affine=affine)
+        up_norm_layer = get_norm_layer(up_norm_layer)
+        if affine is not None:
+            up_norm_layer = partial(up_norm_layer, affine=affine)
+
+        model = [nn.ReflectionPad2d(3),
+                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
+                 norm_layer(ngf),
+                 activation]
+
+        # NOTE(review): `identity` is not used anywhere else in this constructor.
+        identity = Identity()
+        ### downsample
+        for i in range(n_downsampling):
+            mult = 2 ** i
+
+            model += [conv_layer(min(max_features, ngf * mult),
+                                 min(max_features, ngf * mult * 2),
+                                 kernel_size=3, stride=2, padding=1),
+                      norm_layer(min(max_features, ngf * mult * 2)),
+                      activation]
+
+        mult = 2 ** n_downsampling
+        feats_num_bottleneck = min(max_features, ngf * mult)
+
+        ### resnet blocks
+        for i in range(n_blocks):
+            if ffc_positions is not None and i in ffc_positions:
+                # inline=True makes the FFC block take and return a single tensor.
+                model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
+                                         inline=True, **ffc_kwargs)]
+            model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
+                                              conv_layer=resnet_conv_layer, activation=activation,
+                                              norm_layer=norm_layer)]
+
+        ### upsample
+        for i in range(n_downsampling):
+            mult = 2 ** (n_downsampling - i)
+            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
+        model += [nn.ReflectionPad2d(3),
+                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+        if add_out_act:
+            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input):
+        # Runs the whole sequential model on `input`.
+        return self.model(input)
+
+class ConfigGlobalGenerator(nn.Module):
+    # Generator assembled as one nn.Sequential: reflection pad + 7x7 conv stem,
+    # n_downsampling stride-2 convs, a configurable run of residual blocks,
+    # n_downsampling upsampling stages from deconv_factory, and a reflection pad +
+    # 7x7 conv to output_nc.
+    #
+    # manual_block_spec: list of mappings, each describing one group of residual
+    #     blocks. Recognised keys: 'n_blocks', 'use_default', 'resnet_block_kind',
+    #     'resnet_conv_kind', 'multidilation_kwargs', 'resnet_dilation'. When
+    #     'use_default' is truthy the constructor-level resnet_* arguments are used
+    #     for that group. An empty list means one default group of n_blocks blocks.
+    # resnet_block_kind: one of 'multidilatedresnetblock', 'resnetblock',
+    #     'resnetblock5x5', 'resnetblockdwdil'; any other value adds no block.
+    # max_features caps the channel count of the downsampling and bottleneck stages.
+    # affine: when not None, bound as the `affine` argument of both norm constructors.
+    # add_out_act: True appends get_activation('tanh'); any other truthy value is
+    #     passed to get_activation as-is; falsy appends nothing.
+    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3,
+                 n_blocks=3, norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect', conv_kind='default',
+                 deconv_kind='convtranspose', activation=nn.ReLU(True),
+                 up_norm_layer=nn.BatchNorm2d, affine=None, up_activation=nn.ReLU(True),
+                 add_out_act=True, max_features=1024,
+                 manual_block_spec=[],
+                 resnet_block_kind='multidilatedresnetblock',
+                 resnet_conv_kind='multidilated',
+                 resnet_dilation=1,
+                 multidilation_kwargs={}):
+        assert (n_blocks >= 0)
+        super().__init__()
+
+        conv_layer = get_conv_block_ctor(conv_kind)
+        resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs)
+        norm_layer = get_norm_layer(norm_layer)
+        if affine is not None:
+            norm_layer = partial(norm_layer, affine=affine)
+        up_norm_layer = get_norm_layer(up_norm_layer)
+        if affine is not None:
+            up_norm_layer = partial(up_norm_layer, affine=affine)
+
+        model = [nn.ReflectionPad2d(3),
+                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
+                 norm_layer(ngf),
+                 activation]
+
+        ### downsample
+        for i in range(n_downsampling):
+            mult = 2 ** i
+            model += [conv_layer(min(max_features, ngf * mult),
+                                 min(max_features, ngf * mult * 2),
+                                 kernel_size=3, stride=2, padding=1),
+                      norm_layer(min(max_features, ngf * mult * 2)),
+                      activation]
+
+        mult = 2 ** n_downsampling
+        feats_num_bottleneck = min(max_features, ngf * mult)
+
+        if len(manual_block_spec) == 0:
+            manual_block_spec = [
+                DotDict(lambda : None, {
+                    'n_blocks': n_blocks,
+                    'use_default': True})
+            ]
+
+        ### resnet blocks
+        for block_spec in manual_block_spec:
+            block_spec = DotDict(lambda : None, block_spec)
+            # Resolve the per-group settings into fresh names. Re-binding the
+            # constructor-level names inside a nested function only under
+            # `not use_default` makes them unbound locals on the default path.
+            if block_spec.use_default:
+                cur_conv_layer = resnet_conv_layer
+                cur_conv_kind = resnet_conv_kind
+                cur_block_kind = resnet_block_kind
+                cur_dilation = resnet_dilation
+            else:
+                cur_conv_kind = block_spec.resnet_conv_kind
+                cur_block_kind = block_spec.resnet_block_kind
+                # A spec without 'multidilation_kwargs' reads as None; treat it as empty.
+                cur_conv_layer = functools.partial(get_conv_block_ctor(cur_conv_kind),
+                                                   **(block_spec.multidilation_kwargs or {}))
+                # A spec without 'resnet_dilation' keeps the constructor-level value.
+                if block_spec.resnet_dilation is not None:
+                    cur_dilation = block_spec.resnet_dilation
+                else:
+                    cur_dilation = resnet_dilation
+
+            # All block kinds use the capped bottleneck width so they match the
+            # output of the last downsampling layer even when max_features applies.
+            for _ in range(block_spec.n_blocks):
+                if cur_block_kind == "multidilatedresnetblock":
+                    model += [MultidilatedResnetBlock(feats_num_bottleneck, padding_type=padding_type,
+                                                      conv_layer=cur_conv_layer, activation=activation,
+                                                      norm_layer=norm_layer)]
+                elif cur_block_kind == "resnetblock":
+                    model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
+                                          norm_layer=norm_layer, conv_kind=cur_conv_kind)]
+                elif cur_block_kind == "resnetblock5x5":
+                    model += [ResnetBlock5x5(feats_num_bottleneck, padding_type=padding_type, activation=activation,
+                                             norm_layer=norm_layer, conv_kind=cur_conv_kind)]
+                elif cur_block_kind == "resnetblockdwdil":
+                    model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
+                                          norm_layer=norm_layer, conv_kind=cur_conv_kind,
+                                          dilation=cur_dilation, second_dilation=cur_dilation)]
+
+        ### upsample
+        for i in range(n_downsampling):
+            mult = 2 ** (n_downsampling - i)
+            model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features)
+        model += [nn.ReflectionPad2d(3),
+                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+        if add_out_act:
+            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input):
+        # Runs the whole sequential model on `input`.
+        return self.model(input)
+
+
+def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs):
+    """Build `dilated_blocks_n` blocks: 'simple' -> ResnetBlock with dilation 2, 4, 8, ...;
+    'multi' -> MultidilatedResnetBlock; any other kind raises ValueError once a block is requested."""
+    def build(idx):
+        if dilation_block_kind == 'simple':
+            return ResnetBlock(**dilated_block_kwargs, dilation=2 ** (idx + 1))
+        if dilation_block_kind == 'multi':
+            return MultidilatedResnetBlock(**dilated_block_kwargs)
+        raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"')
+    return [build(idx) for idx in range(dilated_blocks_n)]
+
+
+class GlobalGenerator(nn.Module):  # 7x7 stem, strided downsampling, bottleneck blocks, transposed-conv upsampling
+    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect', conv_kind='default', activation=nn.ReLU(True),
+                 up_norm_layer=nn.BatchNorm2d, affine=None,
+                 up_activation=nn.ReLU(True), dilated_blocks_n=0, dilated_blocks_n_start=0,
+                 dilated_blocks_n_middle=0,
+                 add_out_act=True,
+                 max_features=1024, is_resblock_depthwise=False,
+                 ffc_positions=None, ffc_kwargs={}, dilation=1, second_dilation=None,
+                 dilation_block_kind='simple', multidilation_kwargs={}):
+        assert (n_blocks >= 0)
+        super().__init__()
+        # Turn the conv/norm selectors into constructors; `affine`, when given, is bound into both norm layers.
+        conv_layer = get_conv_block_ctor(conv_kind)
+        norm_layer = get_norm_layer(norm_layer)
+        if affine is not None:
+            norm_layer = partial(norm_layer, affine=affine)
+        up_norm_layer = get_norm_layer(up_norm_layer)
+        if affine is not None:
+            up_norm_layer = partial(up_norm_layer, affine=affine)
+        # Counter maps a ResNet block index to how many FFC blocks are inserted in front of it.
+        if ffc_positions is not None:
+            ffc_positions = collections.Counter(ffc_positions)
+
+        model = [nn.ReflectionPad2d(3),
+                 conv_layer(input_nc, ngf, kernel_size=7, padding=0),
+                 norm_layer(ngf),
+                 activation]
+
+        identity = Identity()  # NOTE(review): not referenced again in this constructor
+        ### downsample
+        for i in range(n_downsampling):
+            mult = 2 ** i
+
+            model += [conv_layer(min(max_features, ngf * mult),
+                                 min(max_features, ngf * mult * 2),
+                                 kernel_size=3, stride=2, padding=1),
+                      norm_layer(min(max_features, ngf * mult * 2)),
+                      activation]
+
+        mult = 2 ** n_downsampling
+        feats_num_bottleneck = min(max_features, ngf * mult)  # bottleneck width, capped at max_features
+
+        dilated_block_kwargs = dict(dim=feats_num_bottleneck, padding_type=padding_type,
+                                    activation=activation, norm_layer=norm_layer)
+        if dilation_block_kind == 'simple':
+            dilated_block_kwargs['conv_kind'] = conv_kind
+        elif dilation_block_kind == 'multi':
+            dilated_block_kwargs['conv_layer'] = functools.partial(
+                get_conv_block_ctor('multidilated'), **multidilation_kwargs)
+
+        # dilated blocks at the start of the bottleneck sausage
+        if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0:
+            model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs)
+
+        # resnet blocks
+        for i in range(n_blocks):
+            # dilated blocks at the middle of the bottleneck sausage
+            if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0:
+                model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs)
+
+            if ffc_positions is not None and i in ffc_positions:
+                for _ in range(ffc_positions[i]):  # same position can occur more than once
+                    model += [FFCResnetBlock(feats_num_bottleneck, padding_type, norm_layer, activation_layer=nn.ReLU,
+                                             inline=True, **ffc_kwargs)]
+
+            if is_resblock_depthwise:
+                resblock_groups = feats_num_bottleneck  # one group per channel
+            else:
+                resblock_groups = 1
+
+            model += [ResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation=activation,
+                                  norm_layer=norm_layer, conv_kind=conv_kind, groups=resblock_groups,
+                                  dilation=dilation, second_dilation=second_dilation)]
+
+
+        # dilated blocks at the end of the bottleneck sausage
+        if dilated_blocks_n is not None and dilated_blocks_n > 0:
+            model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs)
+
+        # upsample
+        for i in range(n_downsampling):
+            mult = 2 ** (n_downsampling - i)
+            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
+                                         min(max_features, int(ngf * mult / 2)),
+                                         kernel_size=3, stride=2, padding=1, output_padding=1),
+                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
+                      up_activation]
+        model += [nn.ReflectionPad2d(3),
+                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+        if add_out_act:  # True selects 'tanh'; any other truthy value is passed to get_activation as-is
+            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
+        self.model = nn.Sequential(*model)
+
+    def forward(self, input):
+        return self.model(input)
+
+
+class GlobalGeneratorGated(GlobalGenerator):
+    """GlobalGenerator preset: conv_kind 'gated_bn_relu' with identity activation and norm layer."""
+    def __init__(self, *args, **kwargs):
+        # Caller-supplied keyword arguments take precedence over the presets.
+        merged = {
+            'conv_kind': 'gated_bn_relu', 'activation': nn.Identity(), 'norm_layer': nn.Identity,
+            **kwargs,
+        }
+        super().__init__(*args, **merged)
+
+
+class GlobalGeneratorFromSuperChannels(nn.Module):  # generator whose layer widths come from a `super_channels` list
+    def __init__(self, input_nc, output_nc, n_downsampling, n_blocks, super_channels, norm_layer="bn", padding_type='reflect', add_out_act=True):
+        super().__init__()
+        self.n_downsampling = n_downsampling
+        norm_layer = get_norm_layer(norm_layer)
+        if type(norm_layer) == functools.partial:
+            use_bias = (norm_layer.func == nn.InstanceNorm2d)
+        else:
+            use_bias = (norm_layer == nn.InstanceNorm2d)
+        # Flatten `super_channels` into the per-stage channel list indexed throughout this constructor.
+        channels = self.convert_super_channels(super_channels)
+        self.channels = channels
+
+        model = [nn.ReflectionPad2d(3),
+                 nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias),
+                 norm_layer(channels[0]),
+                 nn.ReLU(True)]
+
+        for i in range(n_downsampling):  # add downsampling layers
+            mult = 2 ** i  # NOTE(review): `mult` is assigned here and below but never read in this class
+            model += [nn.Conv2d(channels[0+i], channels[1+i], kernel_size=3, stride=2, padding=1, bias=use_bias),
+                      norm_layer(channels[1+i]),
+                      nn.ReLU(True)]
+
+        mult = 2 ** n_downsampling
+        # Split the residual blocks into three stages; the last stage takes the remainder.
+        n_blocks1 = n_blocks // 3
+        n_blocks2 = n_blocks1
+        n_blocks3 = n_blocks - n_blocks1 - n_blocks2
+
+        for i in range(n_blocks1):
+            c = n_downsampling
+            dim = channels[c]
+            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)]
+
+        for i in range(n_blocks2):
+            c = n_downsampling+1
+            dim = channels[c]
+            kwargs = {}
+            if i == 0:  # first block of the stage is told the previous stage's width
+                kwargs = {"in_dim": channels[c-1]}
+            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
+
+        for i in range(n_blocks3):
+            c = n_downsampling+2
+            dim = channels[c]
+            kwargs = {}
+            if i == 0:  # first block of the stage is told the previous stage's width
+                kwargs = {"in_dim": channels[c-1]}
+            model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)]
+
+        for i in range(n_downsampling):  # add upsampling layers
+            mult = 2 ** (n_downsampling - i)
+            model += [nn.ConvTranspose2d(channels[n_downsampling+3+i],
+                                         channels[n_downsampling+3+i+1],
+                                         kernel_size=3, stride=2,
+                                         padding=1, output_padding=1,
+                                         bias=use_bias),
+                      norm_layer(channels[n_downsampling+3+i+1]),
+                      nn.ReLU(True)]
+        model += [nn.ReflectionPad2d(3)]
+        model += [nn.Conv2d(channels[2*n_downsampling+3], output_nc, kernel_size=7, padding=0)]
+
+        if add_out_act:  # True selects 'tanh'; any other truthy value is passed to get_activation as-is
+            model.append(get_activation('tanh' if add_out_act is True else add_out_act))
+        self.model = nn.Sequential(*model)
+
+    def convert_super_channels(self, super_channels):
+        n_downsampling = self.n_downsampling  # must already be set; only 2 and 3 are supported below
+        result = []
+        cnt = 0
+
+        if n_downsampling == 2:
+            N1 = 10
+        elif n_downsampling == 3:
+            N1 = 13
+        else:
+            raise NotImplementedError
+        # Downsampling widths: one entry for each of the indices 1, 4, 7, 10 that is below N1.
+        for i in range(0, N1):
+            if i in [1,4,7,10]:
+                channel = super_channels[cnt] * (2 ** cnt)
+                config = {'channel': channel}  # NOTE(review): `config` is never read
+                result.append(channel)
+                logging.info(f"Downsample channels {result[-1]}")
+                cnt += 1
+        # Bottleneck widths: three entries, one per outer iteration (appended when counter == 0).
+        for i in range(3):
+            for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)):
+                if len(super_channels) == 6:
+                    channel = super_channels[3] * 4
+                else:
+                    channel = super_channels[i + 3] * 4
+                config = {'channel': channel}
+                if counter == 0:
+                    result.append(channel)
+                    logging.info(f"Bottleneck channels {result[-1]}")
+        cnt = 2
+        # NOTE(review): 22, 25 and 28 all lie in this range for N1 in (10, 13), so cnt appears to reach -1
+        # on the third hit (index 5 - cnt == 6 for a 6-element list, multiplier 2 ** -1) - confirm intent.
+        for i in range(N1+9, N1+21):
+            if i in [22, 25,28]:
+                cnt -= 1
+                if len(super_channels) == 6:
+                    channel = super_channels[5 - cnt] * (2 ** cnt)
+                else:
+                    channel = super_channels[7 - cnt] * (2 ** cnt)
+                result.append(int(channel))
+                logging.info(f"Upsample channels {result[-1]}")
+        return result
+
+    def forward(self, input):
+        return self.model(input)
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(BaseDiscriminator):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,):
+        super().__init__()
+        self.n_layers = n_layers
+        # Stage 0: stride-2 conv without normalisation.
+        kw = 4
+        padw = int(np.ceil((kw-1.0)/2))
+        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+                     nn.LeakyReLU(0.2, True)]]
+        # Stages 1 .. n_layers-1: stride-2 conv + norm, doubling the width up to a cap of 512.
+        nf = ndf
+        for n in range(1, n_layers):
+            nf_prev = nf
+            nf = min(nf * 2, 512)
+
+            cur_model = []
+            cur_model += [
+                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+                norm_layer(nf),
+                nn.LeakyReLU(0.2, True)
+            ]
+            sequence.append(cur_model)
+        # One more conv + norm stage, this time with stride 1.
+        nf_prev = nf
+        nf = min(nf * 2, 512)
+
+        cur_model = []
+        cur_model += [
+            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+            norm_layer(nf),
+            nn.LeakyReLU(0.2, True)
+        ]
+        sequence.append(cur_model)
+        # Final stage: single-channel prediction conv.
+        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+        # Register each stage as model0 .. model{n_layers + 1} so the stages can be run one by one.
+        for n in range(len(sequence)):
+            setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
+
+    def get_all_activations(self, x):  # returns the output of every stage, in order
+        res = [x]
+        for n in range(self.n_layers + 2):
+            model = getattr(self, 'model' + str(n))
+            res.append(model(res[-1]))
+        return res[1:]  # drop the input itself
+
+    def forward(self, x):  # returns (last stage output, list of the earlier stage outputs)
+        act = self.get_all_activations(x)
+        return act[-1], act[:-1]
+
+
+class MultidilatedNLayerDiscriminator(BaseDiscriminator):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}):
+        super().__init__()
+        self.n_layers = n_layers
+        # Stage 0: plain stride-2 conv without normalisation.
+        kw = 4
+        padw = int(np.ceil((kw-1.0)/2))
+        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+                     nn.LeakyReLU(0.2, True)]]
+        # Stages 1 .. n_layers-1: stride-2 MultidilatedConv + norm, doubling the width up to a cap of 512.
+        nf = ndf
+        for n in range(1, n_layers):
+            nf_prev = nf
+            nf = min(nf * 2, 512)
+
+            cur_model = []
+            cur_model += [
+                MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs),
+                norm_layer(nf),
+                nn.LeakyReLU(0.2, True)
+            ]
+            sequence.append(cur_model)
+        # One more stage with an ordinary stride-1 conv + norm.
+        nf_prev = nf
+        nf = min(nf * 2, 512)
+
+        cur_model = []
+        cur_model += [
+            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+            norm_layer(nf),
+            nn.LeakyReLU(0.2, True)
+        ]
+        sequence.append(cur_model)
+        # Final stage: single-channel prediction conv.
+        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+        # Register each stage as model0 .. model{n_layers + 1} so the stages can be run one by one.
+        for n in range(len(sequence)):
+            setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
+
+    def get_all_activations(self, x):  # returns the output of every stage, in order
+        res = [x]
+        for n in range(self.n_layers + 2):
+            model = getattr(self, 'model' + str(n))
+            res.append(model(res[-1]))
+        return res[1:]  # drop the input itself
+
+    def forward(self, x):  # returns (last stage output, list of the earlier stage outputs)
+        act = self.get_all_activations(x)
+        return act[-1], act[:-1]
+
+
+class NLayerDiscriminatorAsGen(NLayerDiscriminator):  # same network, but forward() yields a single value
+    def forward(self, x):
+        return super().forward(x)[0]  # keep the final prediction, drop the list of intermediate activations
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/spatial_transform.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/spatial_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de024ba08c549605a08b64d096f1f0db7b7722a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/spatial_transform.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from kornia.geometry.transform import rotate
+
+
+class LearnableSpatialTransformWrapper(nn.Module):
+    """Runs `impl` on a reflect-padded, rotated view of the input, then rotates back and crops the padding."""
+    def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
+        super().__init__()
+        self.impl = impl
+        start_angle = torch.rand(1) * angle_init_range
+        # Learnable angles become an nn.Parameter; otherwise the angle stays a plain tensor attribute.
+        self.angle = nn.Parameter(start_angle, requires_grad=True) if train_angle else start_angle
+        self.pad_coef = pad_coef
+
+    def forward(self, x):
+        """Accepts one tensor or a tuple of tensors; any other input type raises ValueError."""
+        if torch.is_tensor(x):
+            return self.inverse_transform(self.impl(self.transform(x)), x)
+        if not isinstance(x, tuple):
+            raise ValueError(f'Unexpected input type {type(x)}')
+        # Tuple input: transform each element, call `impl` once on the tuple, then undo the
+        # transform on each output using the input element at the same position as reference.
+        outputs = self.impl(tuple(self.transform(item) for item in x))
+        return tuple(self.inverse_transform(out, ref) for out, ref in zip(outputs, x))
+
+    def transform(self, x):
+        """Reflect-pad dims 2 and 3 by `pad_coef` times their size, then rotate by `self.angle`."""
+        pad_h, pad_w = (int(extent * self.pad_coef) for extent in x.shape[2:])
+        padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect')
+        return rotate(padded, angle=self.angle.to(padded))
+
+    def inverse_transform(self, y_padded_rotated, orig_x):
+        """Rotate by `-self.angle`, then crop the margin that `transform` would add for `orig_x`."""
+        pad_h, pad_w = (int(extent * self.pad_coef) for extent in orig_x.shape[2:])
+        unrotated = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated))
+        full_h, full_w = unrotated.shape[2:]
+        # Slice bounds are taken from the un-rotated tensor's own spatial size.
+        return unrotated[:, :, pad_h : full_h - pad_h, pad_w : full_w - pad_w]
+
+
+if __name__ == '__main__':  # smoke test, executed only when this file is run directly
+    layer = LearnableSpatialTransformWrapper(nn.Identity())
+    x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float()
+    y = layer(x)  # with an identity `impl` this is pad -> rotate -> rotate back -> crop
+    assert x.shape == y.shape
+    assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1])  # interior only; the 1-pixel border is skipped
+    print('all ok')
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/squeeze_excitation.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/squeeze_excitation.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1d902bb30c071acbc0fa919a134c80fed86bd6c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/modules/squeeze_excitation.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+
+
+class SELayer(nn.Module):
+    """Squeeze-and-excitation style channel gate: rescales each channel by a sigmoid weight."""
+    def __init__(self, channel, reduction=16):
+        super().__init__()
+        hidden = channel // reduction
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        # Bias-free bottleneck MLP: channel -> channel // reduction -> channel, then sigmoid.
+        gate_layers = [nn.Linear(channel, hidden, bias=False), nn.ReLU(inplace=True),
+                       nn.Linear(hidden, channel, bias=False), nn.Sigmoid()]
+        self.fc = nn.Sequential(*gate_layers)
+
+    def forward(self, x):
+        batch, channels, _, _ = x.size()
+        pooled = self.avg_pool(x).view(batch, channels)  # one scalar per (sample, channel)
+        weights = self.fc(pooled).view(batch, channels, 1, 1)
+        # Broadcast the per-channel weights back over the spatial dimensions.
+        return x * weights.expand_as(x)
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8307cd31c2139db0ce581637403b3a95dc8cae59
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/__init__.py
@@ -0,0 +1,29 @@
+import logging
+import torch
+from annotator.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
+
+
+def get_training_model_class(kind):
+    """Map a trainer `kind` name to its class; only 'default' is known, anything else raises ValueError."""
+    if kind != 'default':
+        raise ValueError(f'Unknown trainer module {kind}')
+    return DefaultInpaintingTrainingModule
+
+
+def make_training_model(config):
+    """Build the training module selected by `config.training_model.kind`; every other entry of
+    `config.training_model`, plus a derived `use_ddp` flag, is passed to its constructor."""
+    kind = config.training_model.kind
+    ctor_kwargs = dict(config.training_model)
+    del ctor_kwargs['kind']  # `kind` only selects the class
+    ctor_kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp'
+
+    logging.info(f'Make training model {kind}')
+    return get_training_model_class(kind)(config, **ctor_kwargs)
+
+
+def load_checkpoint(train_config, path, map_location='cuda', strict=True):
+    """Build the training module for `train_config` and return its generator with the state dict at `path` loaded."""
+    generator = make_training_model(train_config).generator
+    generator.load_state_dict(torch.load(path, map_location=map_location), strict=strict)
+    return generator
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/base.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..372dd879a22ff6c3929abf23bb59d6b8b66256b7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/base.py
@@ -0,0 +1,293 @@
+import copy
+import logging
+from typing import Dict, Tuple
+
+import pandas as pd
+import pytorch_lightning as ptl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from torch.utils.data import DistributedSampler
+
+# from annotator.lama.saicinpainting.evaluation import make_evaluator
+# from annotator.lama.saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader
+# from annotator.lama.saicinpainting.training.losses.adversarial import make_discrim_loss
+# from annotator.lama.saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL
+from annotator.lama.saicinpainting.training.modules import make_generator #, make_discriminator
+# from annotator.lama.saicinpainting.training.visualizers import make_visualizer
+from annotator.lama.saicinpainting.utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, \
+ get_has_ddp_rank
+
+LOGGER = logging.getLogger(__name__)
+
+
+def make_optimizer(parameters, kind='adamw', **kwargs):
+    """Create a torch optimizer over `parameters`; `kind` is 'adam' or 'adamw', kwargs go to the optimizer."""
+    optimizer_classes = {'adam': torch.optim.Adam, 'adamw': torch.optim.AdamW}
+    try:
+        optimizer_class = optimizer_classes[kind]
+    except (KeyError, TypeError):
+        raise ValueError(f'Unknown optimizer kind {kind}') from None
+    return optimizer_class(parameters, **kwargs)
+
+
+def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999):
+    """In-place moving average over named parameters: result <- decay * result + (1 - decay) * new_iterate_model."""
+    fresh = dict(new_iterate_model.named_parameters())
+    with torch.no_grad():
+        for name, averaged in result.named_parameters():
+            # KeyError if `new_iterate_model` lacks a parameter that `result` has.
+            averaged.data.mul_(decay).add_(fresh[name].data, alpha=1 - decay)
+
+
+def make_multiscale_noise(base_tensor, scales=6, scale_mode='bilinear'):
+    """Return `scales` Gaussian-noise channels with `base_tensor`'s batch and spatial size: channel k is drawn
+    on a grid whose sides were floor-halved k times (never below 1) and resized back with `scale_mode`."""
+    batch_size, _, height, width = base_tensor.shape
+    align_corners = False if scale_mode in ('bilinear', 'bicubic') else None
+    cur_height, cur_width = height, width
+    result = []
+    for _ in range(scales):
+        cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device)
+        result.append(F.interpolate(cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners))
+        cur_height, cur_width = max(cur_height // 2, 1), max(cur_width // 2, 1)  # never sample a 0-sized grid
+    return torch.cat(result, dim=1)
+
+
+class BaseInpaintingTrainingModule(ptl.LightningModule):
+    """Lightning base module for inpainting: builds the generator and holds the shared train/val step plumbing."""
+    def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100,
+                 average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000,
+                 average_generator_period=10, store_discr_outputs_for_vis=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        LOGGER.info('BaseInpaintingTrainingModule init called')
+
+        self.config = config
+
+        self.generator = make_generator(config, **self.config.generator)
+        self.use_ddp = use_ddp
+
+        if not get_has_ddp_rank():
+            LOGGER.info(f'Generator\n{self.generator}')
+
+        # if not predict_only:
+        #     self.save_hyperparameters(self.config)
+        #     self.discriminator = make_discriminator(**self.config.discriminator)
+        #     self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial)
+        #     self.visualizer = make_visualizer(**self.config.visualizer)
+        #     self.val_evaluator = make_evaluator(**self.config.evaluator)
+        #     self.test_evaluator = make_evaluator(**self.config.evaluator)
+        #
+        #     if not get_has_ddp_rank():
+        #         LOGGER.info(f'Discriminator\n{self.discriminator}')
+        #
+        #     extra_val = self.config.data.get('extra_val', ())
+        #     if extra_val:
+        #         self.extra_val_titles = list(extra_val)
+        #         self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator)
+        #                                                for k in extra_val})
+        #     else:
+        #         self.extra_evaluators = {}
+        # The options below are plain values, so they are stored even though the training setup above is disabled:
+        self.average_generator = average_generator
+        self.generator_avg_beta = generator_avg_beta
+        self.average_generator_start_step = average_generator_start_step
+        self.average_generator_period = average_generator_period
+        self.generator_average = None
+        self.last_generator_averaging_step = -1
+        self.store_discr_outputs_for_vis = store_discr_outputs_for_vis
+        # (get_current_generator(), training_step_end() and _do_step() read them; they used to be left unset.)
+        #     if self.config.losses.get("l1", {"weight_known": 0})['weight_known'] > 0:
+        #         self.loss_l1 = nn.L1Loss(reduction='none')
+        #
+        #     if self.config.losses.get("mse", {"weight": 0})['weight'] > 0:
+        #         self.loss_mse = nn.MSELoss(reduction='none')
+        #
+        #     if self.config.losses.perceptual.weight > 0:
+        #         self.loss_pl = PerceptualLoss()
+        #
+        #     # if self.config.losses.get("resnet_pl", {"weight": 0})['weight'] > 0:
+        #     #     self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl)
+        #     # else:
+        #     #     self.loss_resnet_pl = None
+        #
+        #     self.loss_resnet_pl = None
+
+        self.visualize_each_iters = visualize_each_iters
+        LOGGER.info('BaseInpaintingTrainingModule init done')
+
+    def configure_optimizers(self):  # NOTE(review): needs self.discriminator, which this constructor never creates
+        discriminator_params = list(self.discriminator.parameters())
+        return [
+            dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)),
+            dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)),
+        ]
+
+    def train_dataloader(self):  # NOTE(review): the make_default_*_dataloader imports are commented out in this file
+        kwargs = dict(self.config.data.train)
+        if self.use_ddp:
+            kwargs['ddp_kwargs'] = dict(num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
+                                        rank=self.trainer.global_rank,
+                                        shuffle=True)
+        dataloader = make_default_train_dataloader(**kwargs)  # pass `kwargs` so the ddp_kwargs built above are used
+        return dataloader
+
+    def val_dataloader(self):  # returns [val, visual test (or val again), *extra validation loaders]
+        res = [make_default_val_dataloader(**self.config.data.val)]
+
+        if self.config.data.visual_test is not None:
+            res = res + [make_default_val_dataloader(**self.config.data.visual_test)]
+        else:
+            res = res + res
+
+        extra_val = self.config.data.get('extra_val', ())
+        if extra_val:
+            res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles]
+
+        return res
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        self._is_training_step = True
+        return self._do_step(batch, batch_idx, mode='train', optimizer_idx=optimizer_idx)
+
+    def validation_step(self, batch, batch_idx, dataloader_idx):
+        extra_val_key = None
+        if dataloader_idx == 0:  # loader order matches val_dataloader(): val, test, then extras
+            mode = 'val'
+        elif dataloader_idx == 1:
+            mode = 'test'
+        else:
+            mode = 'extra_val'
+            extra_val_key = self.extra_val_titles[dataloader_idx - 2]
+        self._is_training_step = False
+        return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key)
+
+    def training_step_end(self, batch_parts_outputs):
+        if self.training and self.average_generator \
+                and self.global_step >= self.average_generator_start_step \
+                and self.global_step >= self.last_generator_averaging_step + self.average_generator_period:
+            if self.generator_average is None:  # first averaging step: start from a copy of the generator
+                self.generator_average = copy.deepcopy(self.generator)
+            else:
+                update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta)
+            self.last_generator_averaging_step = self.global_step
+
+        full_loss = (batch_parts_outputs['loss'].mean()
+                     if torch.is_tensor(batch_parts_outputs['loss'])  # loss is not tensor when no discriminator used
+                     else torch.tensor(batch_parts_outputs['loss']).float().requires_grad_(True))
+        log_info = {k: v.mean() for k, v in batch_parts_outputs['log_info'].items()}
+        self.log_dict(log_info, on_step=True, on_epoch=False)
+        return full_loss
+
+    def validation_epoch_end(self, outputs):  # NOTE(review): needs the evaluators that the constructor no longer creates
+        outputs = [step_out for out_group in outputs for step_out in out_group]
+        averaged_logs = average_dicts(step_out['log_info'] for step_out in outputs)
+        self.log_dict({k: v.mean() for k, v in averaged_logs.items()})
+
+        pd.set_option('display.max_columns', 500)
+        pd.set_option('display.width', 1000)
+
+        # standard validation
+        val_evaluator_states = [s['val_evaluator_state'] for s in outputs if 'val_evaluator_state' in s]
+        val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states)
+        val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0)
+        val_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
+        LOGGER.info(f'Validation metrics after epoch #{self.current_epoch}, '
+                    f'total {self.global_step} iterations:\n{val_evaluator_res_df}')
+
+        for k, v in flatten_dict(val_evaluator_res).items():
+            self.log(f'val_{k}', v)
+
+        # standard visual test
+        test_evaluator_states = [s['test_evaluator_state'] for s in outputs
+                                 if 'test_evaluator_state' in s]
+        test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states)
+        test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0)
+        test_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
+        LOGGER.info(f'Test metrics after epoch #{self.current_epoch}, '
+                    f'total {self.global_step} iterations:\n{test_evaluator_res_df}')
+
+        for k, v in flatten_dict(test_evaluator_res).items():
+            self.log(f'test_{k}', v)
+
+        # extra validations
+        if self.extra_evaluators:
+            for cur_eval_title, cur_evaluator in self.extra_evaluators.items():
+                cur_state_key = f'extra_val_{cur_eval_title}_evaluator_state'
+                cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s]
+                cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states)
+                cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0)
+                cur_evaluator_res_df.dropna(axis=1, how='all', inplace=True)
+                LOGGER.info(f'Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, '
+                            f'total {self.global_step} iterations:\n{cur_evaluator_res_df}')
+                for k, v in flatten_dict(cur_evaluator_res).items():
+                    self.log(f'extra_val_{cur_eval_title}_{k}', v)
+
+    def _do_step(self, batch, batch_idx, mode='train', optimizer_idx=None, extra_val_key=None):
+        if optimizer_idx == 0:  # step for generator
+            set_requires_grad(self.generator, True)
+            set_requires_grad(self.discriminator, False)
+        elif optimizer_idx == 1:  # step for discriminator
+            set_requires_grad(self.generator, False)
+            set_requires_grad(self.discriminator, True)
+
+        batch = self(batch)
+
+        total_loss = 0
+        metrics = {}
+
+        if optimizer_idx is None or optimizer_idx == 0:  # step for generator
+            total_loss, metrics = self.generator_loss(batch)
+
+        elif optimizer_idx == 1:  # step for discriminator (the None case is already handled by the branch above)
+            if self.config.losses.adversarial.weight > 0:
+                total_loss, metrics = self.discriminator_loss(batch)
+
+        if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == 'test'):
+            if self.config.losses.adversarial.weight > 0:
+                if self.store_discr_outputs_for_vis:
+                    with torch.no_grad():
+                        self.store_discr_outputs(batch)
+            vis_suffix = f'_{mode}'
+            if mode == 'extra_val':
+                vis_suffix += f'_{extra_val_key}'
+            self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix)
+
+        metrics_prefix = f'{mode}_'
+        if mode == 'extra_val':
+            metrics_prefix += f'{extra_val_key}_'
+        result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix))
+        if mode == 'val':
+            result['val_evaluator_state'] = self.val_evaluator.process_batch(batch)
+        elif mode == 'test':
+            result['test_evaluator_state'] = self.test_evaluator.process_batch(batch)
+        elif mode == 'extra_val':
+            result[f'extra_val_{extra_val_key}_evaluator_state'] = self.extra_evaluators[extra_val_key].process_batch(batch)
+
+        return result
+
+    def get_current_generator(self, no_average=False):  # averaged copy in eval mode when enabled and available
+        if not no_average and not self.training and self.average_generator and self.generator_average is not None:
+            return self.generator_average
+        return self.generator
+
+    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Pass data through generator and obtain at least 'predicted_image' and 'inpainted' keys"""
+        raise NotImplementedError()
+
+    def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # subclass hook: (loss, metrics)
+        raise NotImplementedError()
+
+    def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:  # subclass hook: (loss, metrics)
+        raise NotImplementedError()
+
+    def store_discr_outputs(self, batch):  # adds discriminator maps, resized to the image size, to `batch` in place
+        out_size = batch['image'].shape[2:]
+        discr_real_out, _ = self.discriminator(batch['image'])
+        discr_fake_out, _ = self.discriminator(batch['predicted_image'])
+        batch['discr_output_real'] = F.interpolate(discr_real_out, size=out_size, mode='nearest')
+        batch['discr_output_fake'] = F.interpolate(discr_fake_out, size=out_size, mode='nearest')
+        batch['discr_output_diff'] = batch['discr_output_real'] - batch['discr_output_fake']
+
+    def get_ddp_rank(self):  # global rank when more than one process is in use, else None
+        return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/default.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/default.py
new file mode 100644
index 0000000000000000000000000000000000000000..29cd10ec376d5fe3ebcd957d807d2d3f83b6ec59
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/trainers/default.py
@@ -0,0 +1,175 @@
+import logging
+
+import torch
+import torch.nn.functional as F
+from omegaconf import OmegaConf
+
+# from annotator.lama.saicinpainting.training.data.datasets import make_constant_area_crop_params
+from annotator.lama.saicinpainting.training.losses.distance_weighting import make_mask_distance_weighter
+from annotator.lama.saicinpainting.training.losses.feature_matching import feature_matching_loss, masked_l1_loss
+# from annotator.lama.saicinpainting.training.modules.fake_fakes import FakeFakesGenerator
+from annotator.lama.saicinpainting.training.trainers.base import BaseInpaintingTrainingModule, make_multiscale_noise
+from annotator.lama.saicinpainting.utils import add_prefix_to_keys, get_ramp
+
+LOGGER = logging.getLogger(__name__)
+
+
+def make_constant_area_crop_batch(batch, **kwargs):
+ """Crop batch['image'] and batch['mask'] in place to the same window and return the batch.
+
+ The window (y, x, height, width) comes from make_constant_area_crop_params, which receives the
+ image height/width (dimensions 2 and 3 of batch['image']) plus the forwarded kwargs.
+ """
+ # NOTE(review): the import of make_constant_area_crop_params is commented out at the top of this
+ # file, so unless the name is provided some other way this call raises NameError.
+ crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
+ img_width=batch['image'].shape[3],
+ **kwargs)
+ batch['image'] = batch['image'][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width]
+ batch['mask'] = batch['mask'][:, :, crop_y: crop_y + crop_height, crop_x: crop_x + crop_width]
+ return batch
+
+
+class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
+ """Concrete inpainting training module: implements forward, generator_loss and discriminator_loss.
+
+ forward() masks the input image, optionally appends noise and the mask as extra channels, runs
+ the generator and stores 'predicted_image', 'inpainted' and 'mask_for_losses' in the batch.
+ """
+ def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
+ add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
+ distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
+ fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.concat_mask = concat_mask
+ # Optional schedule (built by get_ramp) mapping the global step to a training resolution.
+ self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
+ # Key of the batch entry that is fed to the discriminator and the losses.
+ self.image_to_discriminator = image_to_discriminator
+ self.add_noise_kwargs = add_noise_kwargs
+ self.noise_fill_hole = noise_fill_hole
+ self.const_area_crop_kwargs = const_area_crop_kwargs
+ # Optional callable that refines the mask used for the losses; None disables refinement.
+ self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
+ if distance_weighter_kwargs is not None else None
+ self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
+
+ self.fake_fakes_proba = fake_fakes_proba
+ if self.fake_fakes_proba > 1e-3:
+ # NOTE(review): the FakeFakesGenerator import is commented out at the top of this file, so
+ # this line raises NameError unless the name is provided some other way.
+ self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
+
+ def forward(self, batch):
+ """Run the generator on the masked image and add the prediction-related entries to `batch`."""
+ # Training only: resize image (bilinear) and mask (nearest) to the scheduled size.
+ if self.training and self.rescale_size_getter is not None:
+ cur_size = self.rescale_size_getter(self.global_step)
+ batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
+ batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
+
+ if self.training and self.const_area_crop_kwargs is not None:
+ batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
+
+ img = batch['image']
+ mask = batch['mask']
+
+ # Zero out the pixels where mask == 1.
+ masked_img = img * (1 - mask)
+
+ if self.add_noise_kwargs is not None:
+ noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs)
+ if self.noise_fill_hole:
+ # Fill the masked area with the first masked_img.shape[1] noise channels.
+ masked_img = masked_img + mask * noise[:, :masked_img.shape[1]]
+ masked_img = torch.cat([masked_img, noise], dim=1)
+
+ if self.concat_mask:
+ # The mask is appended along dim 1 as additional generator input.
+ masked_img = torch.cat([masked_img, mask], dim=1)
+
+ batch['predicted_image'] = self.generator(masked_img)
+ # Prediction inside the mask, original image outside of it.
+ batch['inpainted'] = mask * batch['predicted_image'] + (1 - mask) * batch['image']
+
+ if self.fake_fakes_proba > 1e-3:
+ if self.training and torch.rand(1).item() < self.fake_fakes_proba:
+ batch['fake_fakes'], batch['fake_fakes_masks'] = self.fake_fakes_gen(img, mask)
+ batch['use_fake_fakes'] = True
+ else:
+ # Placeholders keep the batch keys present when fake fakes are not sampled.
+ batch['fake_fakes'] = torch.zeros_like(img)
+ batch['fake_fakes_masks'] = torch.zeros_like(mask)
+ batch['use_fake_fakes'] = False
+
+ # The refined mask is only computed in training mode; otherwise the plain mask is used.
+ batch['mask_for_losses'] = self.refine_mask_for_losses(img, batch['predicted_image'], mask) \
+ if self.refine_mask_for_losses is not None and self.training \
+ else mask
+
+ return batch
+
+ def generator_loss(self, batch):
+ """Return (total generator loss, metrics dict) for a batch already processed by forward()."""
+ img = batch['image']
+ predicted_img = batch[self.image_to_discriminator]
+ original_mask = batch['mask']
+ supervised_mask = batch['mask_for_losses']
+
+ # L1
+ l1_value = masked_l1_loss(predicted_img, img, supervised_mask,
+ self.config.losses.l1.weight_known,
+ self.config.losses.l1.weight_missing)
+
+ total_loss = l1_value
+ metrics = dict(gen_l1=l1_value)
+
+ # vgg-based perceptual loss
+ if self.config.losses.perceptual.weight > 0:
+ pl_value = self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight
+ total_loss = total_loss + pl_value
+ metrics['gen_pl'] = pl_value
+
+ # discriminator
+ # adversarial_loss calls backward by itself
+ mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask
+ self.adversarial_loss.pre_generator_step(real_batch=img, fake_batch=predicted_img,
+ generator=self.generator, discriminator=self.discriminator)
+ discr_real_pred, discr_real_features = self.discriminator(img)
+ discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
+ adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss(real_batch=img,
+ fake_batch=predicted_img,
+ discr_real_pred=discr_real_pred,
+ discr_fake_pred=discr_fake_pred,
+ mask=mask_for_discr)
+ total_loss = total_loss + adv_gen_loss
+ metrics['gen_adv'] = adv_gen_loss
+ metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
+
+ # feature matching
+ if self.config.losses.feature_matching.weight > 0:
+ # 'pass_mask' is optional in the config section, hence the plain-dict .get with a False default.
+ need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get('pass_mask', False)
+ mask_for_fm = supervised_mask if need_mask_in_fm else None
+ fm_value = feature_matching_loss(discr_fake_features, discr_real_features,
+ mask=mask_for_fm) * self.config.losses.feature_matching.weight
+ total_loss = total_loss + fm_value
+ metrics['gen_fm'] = fm_value
+
+ if self.loss_resnet_pl is not None:
+ resnet_pl_value = self.loss_resnet_pl(predicted_img, img)
+ total_loss = total_loss + resnet_pl_value
+ metrics['gen_resnet_pl'] = resnet_pl_value
+
+ return total_loss, metrics
+
+ def discriminator_loss(self, batch):
+ """Return (total discriminator loss, metrics dict) for a batch already processed by forward()."""
+ total_loss = 0
+ metrics = {}
+
+ # detach(): the discriminator step must not propagate gradients into the generator output.
+ predicted_img = batch[self.image_to_discriminator].detach()
+ self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=predicted_img,
+ generator=self.generator, discriminator=self.discriminator)
+ discr_real_pred, discr_real_features = self.discriminator(batch['image'])
+ discr_fake_pred, discr_fake_features = self.discriminator(predicted_img)
+ adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss(real_batch=batch['image'],
+ fake_batch=predicted_img,
+ discr_real_pred=discr_real_pred,
+ discr_fake_pred=discr_fake_pred,
+ mask=batch['mask'])
+ total_loss = total_loss + adv_discr_loss
+ metrics['discr_adv'] = adv_discr_loss
+ metrics.update(add_prefix_to_keys(adv_metrics, 'adv_'))
+
+
+ # Optional second adversarial term computed on the 'fake_fakes' images produced in forward().
+ if batch.get('use_fake_fakes', False):
+ fake_fakes = batch['fake_fakes']
+ self.adversarial_loss.pre_discriminator_step(real_batch=batch['image'], fake_batch=fake_fakes,
+ generator=self.generator, discriminator=self.discriminator)
+ discr_fake_fakes_pred, _ = self.discriminator(fake_fakes)
+ fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss(
+ real_batch=batch['image'],
+ fake_batch=fake_fakes,
+ discr_real_pred=discr_real_pred,
+ discr_fake_pred=discr_fake_fakes_pred,
+ mask=batch['mask']
+ )
+ total_loss = total_loss + fake_fakes_adv_discr_loss
+ metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
+ # NOTE(review): same 'adv_' prefix as above, so identically named metrics from the first
+ # discriminator_loss call are overwritten here - confirm this is intended.
+ metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
+
+ return total_loss, metrics
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/__init__.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d280fd8d48428c249c40c341ecc3c36f34524c99
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/__init__.py
@@ -0,0 +1,15 @@
+import logging
+
+from annotator.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer
+from annotator.lama.saicinpainting.training.visualizers.noop import NoopVisualizer
+
+
+def make_visualizer(kind, **kwargs):
+    """Build the visualizer selected by `kind` ('directory' or 'noop').
+
+    Keyword arguments are forwarded to DirectoryVisualizer only; NoopVisualizer is built
+    without arguments. Any other `kind` raises ValueError.
+    """
+    logging.info(f'Make visualizer {kind}')
+
+    if kind == 'directory':
+        visualizer = DirectoryVisualizer(**kwargs)
+    elif kind == 'noop':
+        visualizer = NoopVisualizer()
+    else:
+        raise ValueError(f'Unknown visualizer kind {kind}')
+    return visualizer
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/base.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..675f01682ddf5e31b6cc341735378c6f3b242e49
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/base.py
@@ -0,0 +1,73 @@
+import abc
+from typing import Dict, List
+
+import numpy as np
+import torch
+from skimage import color
+from skimage.segmentation import mark_boundaries
+
+from . import colors
+
+COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation
+
+
+class BaseVisualizer:
+ """Interface for visualizers: a callable taking (epoch_i, batch_i, batch, suffix, rank)."""
+ # NOTE(review): abc.abstractmethod is applied, but the class declares no ABC base/metaclass, so
+ # instantiation is not blocked; an un-overridden __call__ fails with NotImplementedError at call time.
+ @abc.abstractmethod
+ def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
+ """
+ Take a batch, make an image from it and visualize
+ """
+ raise NotImplementedError()
+
+
+def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
+ last_without_mask=True, rescale_keys=None, mask_only_first=None,
+ black_mask=False) -> np.ndarray:
+ """Render the arrays selected by `keys` side by side, optionally marking the mask boundaries.
+
+ :param images_dict: arrays for one sample; must contain 'mask' and every entry of `keys`
+ :param keys: order in which the images are concatenated (along axis 1 of the result)
+ :param last_without_mask: when True the last image gets no mask boundaries
+ :param rescale_keys: keys whose images are min/max-rescaled before drawing
+ :param mask_only_first: when truthy only the first image gets mask boundaries
+ :param black_mask: when True the masked pixels are zeroed before the boundaries are drawn
+ :return: the concatenated image
+ """
+ # Binarize the mask at 0.5.
+ mask = images_dict['mask'] > 0.5
+ result = []
+ for i, k in enumerate(keys):
+ img = images_dict[k]
+ # Move the first axis to the end (presumably channels-first to channels-last - TODO confirm).
+ img = np.transpose(img, (1, 2, 0))
+
+ if rescale_keys is not None and k in rescale_keys:
+ img = img - img.min()
+ # The 1e-5 offset avoids division by zero for constant images.
+ img /= img.max() + 1e-5
+ if len(img.shape) == 2:
+ img = np.expand_dims(img, 2)
+
+ if img.shape[2] == 1:
+ # Single channel: replicate to three channels.
+ img = np.repeat(img, 3, axis=2)
+ elif (img.shape[2] > 3):
+ # More than three channels: take the argmax over the last axis and color the labels.
+ img_classes = img.argmax(2)
+ img = color.label2rgb(img_classes, colors=COLORS)
+
+ if mask_only_first:
+ need_mark_boundaries = i == 0
+ else:
+ need_mark_boundaries = i < len(keys) - 1 or not last_without_mask
+
+ if need_mark_boundaries:
+ if black_mask:
+ img = img * (1 - mask[0][..., None])
+ img = mark_boundaries(img,
+ mask[0],
+ color=(1., 0., 0.),
+ outline_color=(1., 1., 1.),
+ mode='thick')
+ result.append(img)
+ return np.concatenate(result, axis=1)
+
+
+def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
+ last_without_mask=True, rescale_keys=None) -> np.ndarray:
+ """Render up to `max_items` samples of `batch` with visualize_mask_and_images and stack them along axis 0."""
+ # Only the requested keys and the mask are converted to numpy; other entries are dropped.
+ batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items()
+ if k in keys or k == 'mask'}
+
+ # Batch size is read from the first kept entry.
+ batch_size = next(iter(batch.values())).shape[0]
+ items_to_vis = min(batch_size, max_items)
+ result = []
+ for i in range(items_to_vis):
+ cur_dct = {k: tens[i] for k, tens in batch.items()}
+ result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask,
+ rescale_keys=rescale_keys))
+ return np.concatenate(result, axis=0)
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/colors.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e9e39182c58cb06a1c5e97a7e6c497cc3388ebe
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/colors.py
@@ -0,0 +1,76 @@
+import random
+import colorsys
+
+import numpy as np
+import matplotlib
+matplotlib.use('agg')
+import matplotlib.pyplot as plt
+from matplotlib.colors import LinearSegmentedColormap
+
+
+def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False):
+ # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib
+ """
+ Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
+ :param nlabels: Number of labels (size of colormap)
+ :param type: 'bright' for strong colors, 'soft' for pastel colors
+ :param first_color_black: Option to use first color as black, True or False
+ :param last_color_black: Option to use last color as black, True or False
+ :param verbose: Prints the number of labels and builds a colorbar figure. True or False
+ :return: tuple (list of RGB colors, matplotlib colormap); None when type is neither 'bright' nor 'soft'
+ """
+ # NOTE(review): an unsupported type only prints a message and returns None, so callers that
+ # unpack two values fail on that path.
+ if type not in ('bright', 'soft'):
+ print ('Please choose "bright" or "soft" for type')
+ return
+
+ if verbose:
+ print('Number of labels: ' + str(nlabels))
+
+ # Generate color map for bright colors, based on hsv
+ if type == 'bright':
+ randHSVcolors = [(np.random.uniform(low=0.0, high=1),
+ np.random.uniform(low=0.2, high=1),
+ np.random.uniform(low=0.9, high=1)) for i in range(nlabels)]
+
+ # Convert HSV list to RGB
+ randRGBcolors = []
+ for HSVcolor in randHSVcolors:
+ randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2]))
+
+ if first_color_black:
+ randRGBcolors[0] = [0, 0, 0]
+
+ if last_color_black:
+ randRGBcolors[-1] = [0, 0, 0]
+
+ random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
+
+ # Generate soft pastel colors, by limiting the RGB spectrum
+ if type == 'soft':
+ low = 0.6
+ high = 0.95
+ randRGBcolors = [(np.random.uniform(low=low, high=high),
+ np.random.uniform(low=low, high=high),
+ np.random.uniform(low=low, high=high)) for i in range(nlabels)]
+
+ if first_color_black:
+ randRGBcolors[0] = [0, 0, 0]
+
+ if last_color_black:
+ randRGBcolors[-1] = [0, 0, 0]
+ random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)
+
+ # Display colorbar
+ # NOTE(review): the figure is created but never shown, saved or closed in this function.
+ if verbose:
+ from matplotlib import colors, colorbar
+ from matplotlib import pyplot as plt
+ fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))
+
+ bounds = np.linspace(0, nlabels, nlabels + 1)
+ norm = colors.BoundaryNorm(bounds, nlabels)
+
+ cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
+ boundaries=bounds, format='%1i', orientation=u'horizontal')
+
+ return randRGBcolors, random_colormap
+
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/directory.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/directory.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0a3b5eb93c0738784bf24083bdd54d50e4782f6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/directory.py
@@ -0,0 +1,36 @@
+import os
+
+import cv2
+import numpy as np
+
+from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch
+from annotator.lama.saicinpainting.utils import check_and_warn_input_range
+
+
+class DirectoryVisualizer(BaseVisualizer):
+ """Visualizer that writes one JPEG per call into per-epoch sub-directories of `outdir`."""
+ # Default left-to-right order of the images in each rendered row.
+ DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ')
+
+ def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10,
+ last_without_mask=True, rescale_keys=None):
+ """Store the rendering options and create `outdir` if it does not exist yet."""
+ self.outdir = outdir
+ os.makedirs(self.outdir, exist_ok=True)
+ self.key_order = key_order
+ self.max_items_in_batch = max_items_in_batch
+ self.last_without_mask = last_without_mask
+ self.rescale_keys = rescale_keys
+
+ def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
+ """Render `batch` and save it as <outdir>/epochNNNN<suffix>/batchNNNNNNN[_r<rank>].jpg."""
+ check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image')
+ vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch,
+ last_without_mask=self.last_without_mask,
+ rescale_keys=self.rescale_keys)
+
+ # Scale by 255 and clip into the uint8 range.
+ vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8')
+
+ curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}')
+ os.makedirs(curoutdir, exist_ok=True)
+ # The rank suffix is only added when a rank is given.
+ rank_suffix = f'_r{rank}' if rank is not None else ''
+ out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg')
+
+ # Swap the channel order from RGB to BGR before handing the image to cv2.imwrite.
+ vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out_fname, vis_img)
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/noop.py b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/noop.py
new file mode 100644
index 0000000000000000000000000000000000000000..4479597baf33a817686a4f679b4576f83b6e5c31
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/training/visualizers/noop.py
@@ -0,0 +1,9 @@
+from annotator.lama.saicinpainting.training.visualizers.base import BaseVisualizer
+
+
+class NoopVisualizer(BaseVisualizer):
+ """Visualizer that accepts any constructor arguments and does nothing when called."""
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
+ pass
diff --git a/sd-webui-controlnet/annotator/lama/saicinpainting/utils.py b/sd-webui-controlnet/annotator/lama/saicinpainting/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f36f5130d4c105b63689642da5321ce2e1863a9f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lama/saicinpainting/utils.py
@@ -0,0 +1,174 @@
+import bisect
+import functools
+import logging
+import numbers
+import os
+import signal
+import sys
+import traceback
+import warnings
+
+import torch
+from pytorch_lightning import seed_everything
+
+LOGGER = logging.getLogger(__name__)
+
+
+def check_and_warn_input_range(tensor, min_value, max_value, name):
+    """Warn (via `warnings.warn`) when `tensor` has values outside `min_value`..`max_value`.
+
+    `name` is only used to label the tensor in the warning text. Nothing is returned.
+    """
+    lowest = tensor.min()
+    highest = tensor.max()
+    out_of_range = lowest < min_value or highest > max_value
+    if not out_of_range:
+        return
+    warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {lowest}..{highest}")
+
+
+def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
+    """Add every value of `cur_dict` into `target` in place, under the key `prefix + key`.
+
+    Keys missing from `target` start from `default`.
+    """
+    for name, value in cur_dict.items():
+        prefixed_name = prefix + name
+        running_total = target.get(prefixed_name, default)
+        target[prefixed_name] = running_total + value
+
+
+def average_dicts(dict_list):
+ """Sum the values of all dicts key by key and divide each sum by (number of dicts + 1e-3)."""
+ result = {}
+ # The divisor starts at 1e-3 rather than 0, so the result is slightly below the exact mean
+ # and an empty dict_list cannot cause a division by zero.
+ norm = 1e-3
+ for dct in dict_list:
+ sum_dict_with_prefix(result, dct, '')
+ norm += 1
+ # list(result) snapshots the keys while the values are replaced by in-place division.
+ for k in list(result):
+ result[k] /= norm
+ return result
+
+
+def add_prefix_to_keys(dct, prefix):
+    """Return a new dict whose keys are `prefix + key` for every key of `dct`; values are unchanged."""
+    renamed = {}
+    for key, value in dct.items():
+        renamed[prefix + key] = value
+    return renamed
+
+
+def set_requires_grad(module, value):
+    """Assign `value` to the `requires_grad` attribute of every item yielded by `module.parameters()`."""
+    for parameter in module.parameters():
+        setattr(parameter, 'requires_grad', value)
+
+
+def flatten_dict(dct):
+    """Flatten nested dicts into a single level, joining the key path with underscores.
+
+    Tuple keys are first joined with '_' themselves; non-dict values are copied as-is.
+    """
+    flat = {}
+    for key, value in dct.items():
+        name = '_'.join(key) if isinstance(key, tuple) else key
+        if not isinstance(value, dict):
+            flat[name] = value
+            continue
+        # Nested dict: flatten it recursively and prepend this level's name.
+        for inner_name, inner_value in flatten_dict(value).items():
+            flat[f'{name}_{inner_name}'] = inner_value
+    return flat
+
+
+class LinearRamp:
+    """Schedule that moves linearly from `start_value` to `end_value` between two iterations.
+
+    Before `start_iter` the start value is returned, from `end_iter` on the end value.
+    """
+
+    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
+        self.start_value, self.end_value = start_value, end_value
+        self.start_iter, self.end_iter = start_iter, end_iter
+
+    def __call__(self, i):
+        """Return the scheduled value for iteration `i`."""
+        if i < self.start_iter:
+            return self.start_value
+        if i >= self.end_iter:
+            return self.end_value
+        # Fraction of the ramp covered so far, in [0, 1).
+        ramp_length = self.end_iter - self.start_iter
+        part = (i - self.start_iter) / ramp_length
+        return self.start_value * (1 - part) + self.end_value * part
+
+
+class LadderRamp:
+    """Piecewise-constant schedule: `values[k]` is used once `k` entries of `start_iters` are <= i.
+
+    `values` must hold exactly one more element than `start_iters`.
+    """
+
+    def __init__(self, start_iters, values):
+        self.start_iters = start_iters
+        self.values = values
+        n_values, n_steps = len(values), len(start_iters)
+        assert n_values == n_steps + 1, (n_values, n_steps)
+
+    def __call__(self, i):
+        """Return the value of the segment that iteration `i` falls into."""
+        return self.values[bisect.bisect_right(self.start_iters, i)]
+
+
+def get_ramp(kind='ladder', **kwargs):
+    """Create the ramp schedule named by `kind` ('linear' or 'ladder'), forwarding `kwargs` to it.
+
+    Raises ValueError for any other `kind`.
+    """
+    if kind == 'linear':
+        ramp = LinearRamp(**kwargs)
+    elif kind == 'ladder':
+        ramp = LadderRamp(**kwargs)
+    else:
+        raise ValueError(f'Unexpected ramp kind: {kind}')
+    return ramp
+
+
+def print_traceback_handler(sig, frame):
+ """Signal handler that logs the received signal and the current stack at WARNING level."""
+ LOGGER.warning(f'Received signal {sig}')
+ # format_stack() is called without arguments, so `frame` itself is not used.
+ bt = ''.join(traceback.format_stack())
+ LOGGER.warning(f'Requested stack trace:\n{bt}')
+
+
+def register_debug_signal_handlers(sig=None, handler=print_traceback_handler):
+    """Install `handler` for signal `sig`.
+
+    :param sig: signal number to hook. ``None`` (the default) selects ``signal.SIGUSR1``
+        when the running platform defines it.
+    :param handler: callable taking ``(signum, frame)``; defaults to `print_traceback_handler`.
+    :return: None. When no signal was given and no default exists on this platform,
+        nothing is installed and a warning is logged.
+    """
+    if sig is None:
+        # signal.signal() rejects None with TypeError, so the bare default call used to crash.
+        # SIGUSR1 is looked up dynamically because not every platform defines it.
+        sig = getattr(signal, 'SIGUSR1', None)
+        if sig is None:
+            LOGGER.warning('No default debug signal available on this platform, handler not installed')
+            return
+    LOGGER.warning(f'Setting signal {sig} handler {handler}')
+    signal.signal(sig, handler)
+
+
+def handle_deterministic_config(config):
+ """Call seed_everything with config's 'seed' entry when present; return whether seeding happened."""
+ # dict(config) gives a plain mapping, so a missing 'seed' yields None.
+ seed = dict(config).get('seed', None)
+ if seed is None:
+ return False
+
+ seed_everything(seed)
+ return True
+
+
+def get_shape(t):
+ """Describe the shape of a (possibly nested) structure.
+
+ Tensors map to their shape tuple, dicts to a dict of descriptions, lists/tuples to a list of
+ descriptions and numbers to their type. Any other input raises ValueError.
+ """
+ if torch.is_tensor(t):
+ return tuple(t.shape)
+ elif isinstance(t, dict):
+ return {n: get_shape(q) for n, q in t.items()}
+ elif isinstance(t, (list, tuple)):
+ # Tuples are also returned as lists.
+ return [get_shape(q) for q in t]
+ elif isinstance(t, numbers.Number):
+ return type(t)
+ else:
+ raise ValueError('unexpected type {}'.format(type(t)))
+
+
+def get_has_ddp_rank():
+    """Return True when any of MASTER_PORT, NODE_RANK, LOCAL_RANK or WORLD_SIZE is set in the environment."""
+    ddp_variables = ('MASTER_PORT', 'NODE_RANK', 'LOCAL_RANK', 'WORLD_SIZE')
+    return any(os.environ.get(variable, None) is not None for variable in ddp_variables)
+
+
+def handle_ddp_subprocess():
+    """Build a decorator for an entry point that may be re-launched as a DDP worker process.
+
+    The wrapped function first checks that TRAINING_PARENT_WORK_DIR and the DDP rank variables
+    are either all absent or both present. In a worker (parent directory set) it appends a
+    `hydra.run.dir=<parent dir>` override to sys.argv. It then calls the original function
+    with the unchanged arguments and returns its result.
+
+    :return: decorator preserving the wrapped function's metadata via functools.wraps
+    :raises AssertionError: from the wrapped call when the environment state is inconsistent
+    """
+    def main_decorator(main_func):
+        @functools.wraps(main_func)
+        def new_main(*args, **kwargs):
+            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
+            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
+            has_parent = parent_cwd is not None
+            has_rank = get_has_ddp_rank()
+            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
+
+            if has_parent:
+                # we are in the worker
+                sys.argv.extend([
+                    f'hydra.run.dir={parent_cwd}',
+                    # 'hydra/hydra_logging=disabled',
+                    # 'hydra/job_logging=disabled'
+                ])
+            # do nothing if this is a top-level process
+            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization
+
+            # Hand the wrapped function's result back to the caller instead of dropping it.
+            return main_func(*args, **kwargs)
+        return new_main
+    return main_decorator
+
+
+def handle_ddp_parent_process():
+ """Record the current working directory in TRAINING_PARENT_WORK_DIR if it is not set yet.
+
+ :return: True when the variable was already set before the call, False otherwise
+ :raises AssertionError: when the variable and the DDP rank variables disagree
+ """
+ parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
+ has_parent = parent_cwd is not None
+ has_rank = get_has_ddp_rank()
+ assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'
+
+ if parent_cwd is None:
+ # Stored in the environment so that it is visible through os.environ afterwards.
+ os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()
+
+ return has_parent
diff --git a/sd-webui-controlnet/annotator/leres/__init__.py b/sd-webui-controlnet/annotator/leres/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b11e44a954b68a634326d097bcb54b8876524b4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/__init__.py
@@ -0,0 +1,113 @@
+import cv2
+import numpy as np
+import torch
+import os
+from modules import devices, shared
+from annotator.annotator_path import models_path
+from torchvision.transforms import transforms
+
+# AdelaiDepth/LeReS imports
+from .leres.depthmap import estimateleres, estimateboost
+from .leres.multi_depth_model_woauxi import RelDepthModel
+from .leres.net_tools import strip_prefix_if_present
+
+# pix2pix/merge net imports
+from .pix2pix.options.test_options import TestOptions
+from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
+
+# Directory where the LeReS / pix2pix checkpoints are looked up and downloaded to.
+base_model_path = os.path.join(models_path, "leres")
+# Directory of this file; apply_leres also looks for res101.pth here and prefers it when present.
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+
+# Download locations used when a checkpoint is missing locally.
+remote_model_path_leres = "https://huggingface.co/lllyasviel/Annotators/resolve/main/res101.pth"
+remote_model_path_pix2pix = "https://huggingface.co/lllyasviel/Annotators/resolve/main/latest_net_G.pth"
+
+# Lazily created module-level model instances (see apply_leres / unload_leres_model).
+model = None
+pix2pixmodel = None
+
+def unload_leres_model():
+ """Move the depth model to the CPU and unload network 'G' of the pix2pix model, if they are loaded."""
+ global model, pix2pixmodel
+ if model is not None:
+ model = model.cpu()
+ if pix2pixmodel is not None:
+ # NOTE(review): the global is rebound to whatever unload_network returns; if that is None,
+ # the next boosted apply_leres call rebuilds the pix2pix model - confirm against unload_network.
+ pix2pixmodel = pix2pixmodel.unload_network('G')
+
+
+def apply_leres(input_image, thr_a, thr_b, boost=False):
+ """Estimate a depth map for `input_image` with LeReS and return it as an 8-bit image.
+
+ :param input_image: array with exactly three dimensions (asserted); unpacked as height, width, dim
+ :param thr_a: percentage (0 disables) used as a to-zero threshold before the image is inverted
+ :param thr_b: percentage (0 disables) used as a to-zero threshold after the image is inverted
+ :param boost: when True use estimateboost with the pix2pix merge network instead of estimateleres
+ :return: inverted, thresholded 8-bit depth image
+ """
+ global model, pix2pixmodel
+ # Lazy one-time construction of the depth model.
+ if model is None:
+ model_path = os.path.join(base_model_path, "res101.pth")
+ old_model_path = os.path.join(old_modeldir, "res101.pth")
+
+ # A checkpoint next to this file takes precedence; otherwise download if missing.
+ if os.path.exists(old_model_path):
+ model_path = old_model_path
+ elif not os.path.exists(model_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path_leres, model_dir=base_model_path)
+
+ if torch.cuda.is_available():
+ checkpoint = torch.load(model_path)
+ else:
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
+
+ model = RelDepthModel(backbone='resnext101')
+ # The "module." prefix is stripped from the state-dict keys before strict loading.
+ model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
+ del checkpoint
+
+ # Lazy one-time construction of the pix2pix merge network, only needed for boost mode.
+ if boost and pix2pixmodel is None:
+ pix2pixmodel_path = os.path.join(base_model_path, "latest_net_G.pth")
+ if not os.path.exists(pix2pixmodel_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path_pix2pix, model_dir=base_model_path)
+
+ opt = TestOptions().parse()
+ if not torch.cuda.is_available():
+ opt.gpu_ids = [] # cpu mode
+ pix2pixmodel = Pix2Pix4DepthModel(opt)
+ pix2pixmodel.save_dir = base_model_path
+ pix2pixmodel.load_networks('latest')
+ pix2pixmodel.eval()
+
+ # The depth model is moved to the controlnet device unless that device is of type 'mps'.
+ if devices.get_device_for("controlnet").type != 'mps':
+ model = model.to(devices.get_device_for("controlnet"))
+
+ assert input_image.ndim == 3
+ height, width, dim = input_image.shape
+
+ with torch.no_grad():
+
+ if boost:
+ depth = estimateboost(input_image, model, 0, pix2pixmodel, max(width, height))
+ else:
+ depth = estimateleres(input_image, model, width, height)
+
+ numbytes=2
+ depth_min = depth.min()
+ depth_max = depth.max()
+ # Largest value representable with `numbytes` bytes (65535 for 2 bytes).
+ max_val = (2**(8*numbytes))-1
+
+ # check output before normalizing and mapping to 16 bit
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ # Constant depth map: avoid dividing by (almost) zero.
+ out = np.zeros(depth.shape)
+
+ # single channel, 16 bit image
+ depth_image = out.astype("uint16")
+
+ # convert to uint8
+ depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))
+
+ # remove near
+ if thr_a != 0:
+ # Percentage converted to the 0..255 scale.
+ thr_a = ((thr_a/100)*255)
+ depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
+
+ # invert image
+ depth_image = cv2.bitwise_not(depth_image)
+
+ # remove bg
+ if thr_b != 0:
+ # Percentage converted to the 0..255 scale.
+ thr_b = ((thr_b/100)*255)
+ depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
+
+ return depth_image
diff --git a/sd-webui-controlnet/annotator/leres/leres/LICENSE b/sd-webui-controlnet/annotator/leres/leres/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..e0f1d07d98d4e85e684734d058dfe2515d215405
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/LICENSE
@@ -0,0 +1,23 @@
+https://github.com/thygate/stable-diffusion-webui-depthmap-script
+
+MIT License
+
+Copyright (c) 2023 Bob Thiry
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/leres/leres/Resnet.py b/sd-webui-controlnet/annotator/leres/leres/Resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f12c9975c1aa05401269be3ca3dbaa56bde55581
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/Resnet.py
@@ -0,0 +1,199 @@
+import torch.nn as nn
+import torch.nn as NN
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152']
+
+
+# Checkpoint URLs keyed by architecture name. Nothing else in this file reads this mapping.
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding
+
+ Returns an nn.Conv2d with kernel_size=3, padding=1, the given stride and no bias.
+ """
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ """Residual block made of two 3x3 conv + batch-norm stages with a skip connection."""
+ # Ratio of output channels to `planes`.
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ # Optional module applied to the input before it is added to the output.
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ # Skip connection, followed by the final activation.
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Residual block with 1x1 -> 3x3 -> 1x1 convolutions; the last one outputs planes * expansion channels."""
+ # Ratio of output channels to `planes`.
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ # The stride is applied in the 3x3 convolution.
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d
+ self.relu = nn.ReLU(inplace=True)
+ # Optional module applied to the input before it is added to the output.
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ # Skip connection, followed by the final activation.
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ """ResNet feature extractor: forward() returns the outputs of layer1..layer4 as a list.
+
+ The average-pooling and fully connected layers are commented out, so `num_classes` is unused.
+ """
+
+ def __init__(self, block, layers, num_classes=1000):
+ # Running input-channel count, updated by _make_layer.
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ #self.avgpool = nn.AvgPool2d(7, stride=1)
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ # Weight initialisation: Kaiming-normal for convolutions, constant 1/0 for batch-norm weight/bias.
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ """Stack `blocks` instances of `block`; only the first one gets `stride` and the downsample branch."""
+ downsample = None
+ # A 1x1 conv + batch-norm on the skip path is needed when the stride or the channel count changes.
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """Return [layer1(x), layer2(...), layer3(...), layer4(...)] computed after the stem."""
+ features = []
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ features.append(x)
+ x = self.layer2(x)
+ features.append(x)
+ x = self.layer3(x)
+ features.append(x)
+ x = self.layer4(x)
+ features.append(x)
+
+ return features
+
+
+def resnet18(pretrained=True, **kwargs):
+ """Constructs a ResNet-18 model.
+ Args:
+ pretrained (bool): not used by this function; no weights are loaded here
+ """
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+ return model
+
+
+def resnet34(pretrained=True, **kwargs):
+ """Constructs a ResNet-34 model.
+ Args:
+ pretrained (bool): not used by this function; no weights are loaded here
+ """
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+ return model
+
+
+def resnet50(pretrained=True, **kwargs):
+ """Constructs a ResNet-50 model.
+ Args:
+ pretrained (bool): not used by this function; no weights are loaded here
+ """
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+
+ return model
+
+
+def resnet101(pretrained=True, **kwargs):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): not used by this function; no weights are loaded here
+ """
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+
+ return model
+
+
+def resnet152(pretrained=True, **kwargs):
+ """Constructs a ResNet-152 model.
+ Args:
+ pretrained (bool): not used by this function; no weights are loaded here
+ """
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+ return model
diff --git a/sd-webui-controlnet/annotator/leres/leres/Resnext_torch.py b/sd-webui-controlnet/annotator/leres/leres/Resnext_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af54fcc3e5b363935ef60c8aaf269110c0d6611
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/Resnext_torch.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python
+# coding: utf-8
+import torch.nn as nn
+
+try:
+ from urllib import urlretrieve
+except ImportError:
+ from urllib.request import urlretrieve
+
+# Only the ResNeXt-101 32x8d factory is exported from this module.
+__all__ = ['resnext101_32x8d']
+
+
+# Download URLs of the upstream checkpoints, keyed by architecture name.
+# Nothing else in this module reads this table: the factory below builds the
+# network without fetching or loading any weights.
+model_urls = {
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``."""
+    conv_options = dict(kernel_size=3, stride=stride, padding=dilation,
+                        groups=groups, bias=False, dilation=dilation)
+    return nn.Conv2d(in_planes, out_planes, **conv_options)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """Return a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
+    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+    return pointwise
+
+
+class BasicBlock(nn.Module):
+    """Residual block made of two 3x3 convolutions plus a skip connection.
+
+    Only ``groups=1``, ``base_width=64`` and ``dilation<=1`` are accepted; any
+    other value raises in the constructor.
+    """
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
+        if not (groups == 1 and base_width == 64):
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # When stride != 1 the first convolution and the optional ``downsample``
+        # branch both reduce the spatial size, so the two paths stay aligned.
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = make_norm(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = make_norm(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        """Apply conv-norm-ReLU, conv-norm, add the shortcut, then ReLU."""
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+
+        # The shortcut is the raw input unless a projection was supplied.
+        shortcut = x if self.downsample is None else self.downsample(x)
+
+        out += shortcut
+        return self.relu(out)
+
+
+class Bottleneck(nn.Module):
+    """Residual block of 1x1 -> 3x3 -> 1x1 convolutions with a skip connection.
+
+    The stride is applied by the 3x3 convolution (``conv2``), not by the first
+    1x1 convolution as in the original paper
+    ("Deep residual learning for image recognition", https://arxiv.org/abs/1512.03385).
+    This variant is also known as ResNet V1.5 and improves accuracy according to
+    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+    """
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
+        # Channel count of the inner 3x3 convolution.
+        width = int(planes * (base_width / 64.)) * groups
+        out_planes = planes * self.expansion
+        # When stride != 1, ``conv2`` and the optional ``downsample`` branch both
+        # reduce the spatial size.
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = make_norm(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = make_norm(width)
+        self.conv3 = conv1x1(width, out_planes)
+        self.bn3 = make_norm(out_planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        """Apply the three conv-norm stages, add the shortcut, then ReLU."""
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.bn3(self.conv3(out))
+
+        # The shortcut is the raw input unless a projection was supplied.
+        shortcut = x if self.downsample is None else self.downsample(x)
+
+        out += shortcut
+        return self.relu(out)
+
+
+class ResNet(nn.Module):
+    """ResNet/ResNeXt trunk that returns the outputs of its four residual stages.
+
+    The classification head (average pool and fully connected layer) is commented
+    out, so ``forward`` returns a list of four tensors instead of class logits.
+
+    Args:
+        block: Residual block class (``BasicBlock`` or ``Bottleneck``); its
+            ``expansion`` attribute is read when sizing the stages.
+        layers: Indexable with four ints, the number of blocks in each stage.
+        num_classes: Not used by this class (the ``fc`` layer is commented out).
+        zero_init_residual: If True, zero the weight of the last norm layer of
+            every ``Bottleneck``/``BasicBlock``.
+        groups: Group count forwarded to every block.
+        width_per_group: Base width forwarded to every block.
+        replace_stride_with_dilation: ``None`` or three booleans; a true entry makes
+            the matching stage (layer2, layer3, layer4) dilate instead of striding.
+        norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.
+
+    Raises:
+        ValueError: If ``replace_stride_with_dilation`` does not have three entries.
+    """
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        # Running state consumed and updated by _make_layer while stages are built.
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        # Stem: 7x7 stride-2 convolution on a 3-channel input, then norm, ReLU, max-pool.
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        # Classification head intentionally disabled; only feature maps are produced.
+        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        #self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        # Weight initialization: Kaiming-normal for convolutions, unit weight and
+        # zero bias for BatchNorm/GroupNorm layers.
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.
+
+        Updates ``self.inplanes`` to ``planes * block.expansion`` and, when
+        ``dilate`` is true, multiplies ``self.dilation`` by ``stride`` and uses
+        stride 1 for the stage instead.
+        """
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        # A 1x1 projection is needed on the shortcut whenever the first block
+        # changes the spatial size or the channel count.
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        # The first block receives the dilation in effect before this stage;
+        # the remaining blocks use the (possibly updated) self.dilation.
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        """Run the stem and all four stages; return the four stage outputs as a list."""
+        # See note [TorchScript super()]
+        features = []
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        features.append(x)
+
+        x = self.layer2(x)
+        features.append(x)
+
+        x = self.layer3(x)
+        features.append(x)
+
+        x = self.layer4(x)
+        features.append(x)
+
+        # Classification head disabled (see __init__).
+        #x = self.avgpool(x)
+        #x = torch.flatten(x, 1)
+        #x = self.fc(x)
+
+        return features
+
+    def forward(self, x):
+        """Return the list of stage feature maps produced by ``_forward_impl``."""
+        return self._forward_impl(x)
+
+
+
+def resnext101_32x8d(pretrained=True, **kwargs):
+    """Constructs a ResNeXt-101 32x8d model (32 groups, width 8 per group).
+
+    Args:
+        pretrained (bool): Accepted but not used by this function; no weights
+            are downloaded or loaded here.
+        **kwargs: Forwarded to ``ResNet``; ``groups`` and ``width_per_group``
+            are overwritten with 32 and 8.
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    return model
+
diff --git a/sd-webui-controlnet/annotator/leres/leres/depthmap.py b/sd-webui-controlnet/annotator/leres/leres/depthmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebceecbe28ec248f6f96bb65b1c53bdbaf393ecc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/depthmap.py
@@ -0,0 +1,546 @@
+# Author: thygate
+# https://github.com/thygate/stable-diffusion-webui-depthmap-script
+
+from modules import devices
+from modules.shared import opts
+from torchvision.transforms import transforms
+from operator import getitem
+
+import torch, gc
+import cv2
+import numpy as np
+import skimage.measure
+
+# Maximum resolution used by the boosting pipeline. estimateboost() overwrites this
+# module global when opts.depthmap_script_boost_rmax is present.
+whole_size_threshold = 1600 # R_max from the paper
+# Side length that estimates are resized to before being passed to the merge network.
+pix2pixsize = 1024
+
+def scale_torch(img):
+    """
+    Convert a numpy image or depth map into a float32 torch tensor.
+
+    :param img: input rgb is in shape [H, W, 3], input depth/disp is in shape [H, W]
+    :return: torch tensor. A 3-channel input goes through ToTensor and is normalized
+             with the mean/std constants below; a 2-D input becomes [1, H, W]; any
+             other 3-D input is converted as-is, without reordering its axes.
+    """
+    if len(img.shape) == 2:
+        # A 2-D map is never RGB. Returning here keeps a map whose width happens
+        # to be 3 from being mistaken for a 3-channel image by the test below.
+        return torch.from_numpy(img[np.newaxis, :, :].astype(np.float32))
+    if img.shape[2] == 3:
+        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )])
+        return transform(img.astype(np.float32))
+    return torch.from_numpy(img.astype(np.float32))
+
+def estimateleres(img, model, w, h):
+    """Run ``model.depth_model`` on ``img`` at a working size of ``w`` x ``h``.
+
+    The last axis of ``img`` is reversed, the result is resized to (w, h),
+    converted with ``scale_torch`` and evaluated under ``torch.no_grad()``.
+    The prediction is returned as a numpy array resized back to the height
+    and width of ``img``.
+    """
+    out_height, out_width = img.shape[0], img.shape[1]
+
+    # leres transform input
+    flipped = img[:, :, ::-1].copy()
+    resized = cv2.resize(flipped, (w, h))
+    batch = scale_torch(resized)[None, :, :, :]
+
+    # compute
+    with torch.no_grad():
+        batch = batch.to(devices.get_device_for("controlnet"))
+        raw = model.depth_model(batch)
+
+    depth = raw.squeeze().cpu().numpy()
+    return cv2.resize(depth, (out_width, out_height), interpolation=cv2.INTER_CUBIC)
+
+def generatemask(size):
+    """Generate a Gaussian-blurred blending mask of shape ``size``, rescaled to [0, 1].
+
+    The central region (a 15% margin removed on every side) is set to 1, blurred
+    with a kernel derived from ``size[0]``, then min-max normalized.
+    """
+    rows, cols = size[0], size[1]
+    sigma = int(rows / 16)
+    k_size = int(2 * np.ceil(2 * sigma) + 1)
+    row_margin = int(0.15 * rows)
+    col_margin = int(0.15 * cols)
+
+    box = np.zeros(size, dtype=np.float32)
+    box[row_margin:rows - row_margin, col_margin:cols - col_margin] = 1
+
+    blurred = cv2.GaussianBlur(box, (k_size, k_size), sigma)
+    lowest, highest = blurred.min(), blurred.max()
+    return ((blurred - lowest) / (highest - lowest)).astype(np.float32)
+
+def resizewithpool(img, size):
+    """Max-pool ``img`` with square blocks of side ``floor(img.shape[0] / size)``."""
+    block_side = int(np.floor(img.shape[0] / size))
+    return skimage.measure.block_reduce(img, (block_side, block_side), np.max)
+
+def rgb2gray(rgb):
+    """Convert rgb to gray as a weighted sum of the first three channels."""
+    channel_weights = [0.2989, 0.5870, 0.1140]
+    first_three = rgb[..., :3]
+    return np.dot(first_three, channel_weights)
+
+def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000):
+    """Search for the R_x processing resolution; returns ``(resolution, patch_scale)``."""
+    # Returns the R_x resolution described in section 5 of the main paper.
+
+    # Parameters:
+    #    img :input rgb image
+    #    basesize : size the dilation kernel which is equal to receptive field of the network.
+    #    confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue.
+    #    scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3.
+    #    whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper)
+
+    # Returns:
+    #    outputsize_scale*speed_scale :The computed R_x resolution
+    #    patch_scale: K parameter from section 6 of the paper
+
+    # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search
+    speed_scale = 32
+    image_dim = int(min(img.shape[0:2]))
+
+    # Edge strength: sum of absolute vertical and horizontal Sobel responses.
+    gray = rgb2gray(img)
+    grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3))
+    # NOTE(review): in the OpenCV Python signature the third positional parameter of
+    # cv2.resize is `dst`, not `interpolation` -- confirm whether INTER_AREA (here) and
+    # INTER_NEAREST (in the loop below) are actually being applied.
+    grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA)
+
+    # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues
+    m = grad.min()
+    M = grad.max()
+    middle = m + (0.4 * (M - m))
+    grad[grad < middle] = 0
+    grad[grad >= middle] = 1
+
+    # dilation kernel with size of the receptive field
+    kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float)
+    # dilation kernel with size of the a quarter of receptive field used to compute k
+    # as described in section 6 of main paper
+    kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float)
+
+    # Output resolution limit set by the whole_size_threshold and scale_threshold.
+    threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2]))
+
+    # Grow the candidate size in steps of half the (scaled) receptive field and keep the
+    # last size whose share of pixels without a nearby edge stays within `confidence`.
+    outputsize_scale = basesize / speed_scale
+    for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))):
+        grad_resized = resizewithpool(grad, p_size)
+        grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST)
+        grad_resized[grad_resized >= 0.5] = 1
+        grad_resized[grad_resized < 0.5] = 0
+
+        dilated = cv2.dilate(grad_resized, kernel, iterations=1)
+        meanvalue = (1-dilated).mean()
+        if meanvalue > confidence:
+            break
+        else:
+            outputsize_scale = p_size
+
+    # Uses grad_resized from the last loop iteration; it is undefined if the range
+    # above is empty.
+    grad_region = cv2.dilate(grad_resized, kernel2, iterations=1)
+    patch_scale = grad_region.mean()
+
+    return int(outputsize_scale*speed_scale), patch_scale
+
+# Generate a double-input depth estimation
+def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel):
+    """Merge a low-resolution and a high-resolution estimate with the merge network.
+
+    Both estimates come from ``singleestimate`` (at ``size1`` and ``size2``) and are
+    resized to ``pix2pixsize`` x ``pix2pixsize`` before being handed to
+    ``pix2pixmodel``. The merged output is min-max normalized and returned as a
+    numpy array.
+    """
+    merge_dims = (pix2pixsize, pix2pixsize)
+
+    # Low resolution estimate, resized to the inference size of the merge network.
+    low_res = singleestimate(img, size1, model, net_type)
+    low_res = cv2.resize(low_res, merge_dims, interpolation=cv2.INTER_CUBIC)
+
+    # High resolution estimate, resized the same way.
+    high_res = singleestimate(img, size2, model, net_type)
+    high_res = cv2.resize(high_res, merge_dims, interpolation=cv2.INTER_CUBIC)
+
+    # Inference on the merge model
+    pix2pixmodel.set_input(low_res, high_res)
+    pix2pixmodel.test()
+    merged = pix2pixmodel.get_current_visuals()['fake_B']
+
+    # Shift by one and halve, then stretch to the full [0, 1] range.
+    merged = (merged+1)/2
+    merged = (merged - torch.min(merged)) / (
+        torch.max(merged) - torch.min(merged))
+
+    return merged.squeeze().cpu().numpy()
+
+# Generate a single-input depth estimation
+def singleestimate(img, msize, model, net_type):
+    """Estimate depth for ``img`` at a square working size of ``msize``.
+
+    ``net_type`` is currently ignored: the alternative branch
+    (``estimatemidasBoost``) is disabled and every call goes to ``estimateleres``.
+    """
+    side = msize
+    return estimateleres(img, model, side, side)
+
+def applyGridpatch(blsize, stride, img, box):
+    """Lay a regular grid of square patches of side ``2 * blsize`` over ``img``.
+
+    Returns a dict keyed by the patch counter as a string ('0', '1', ...). Each value
+    holds ``'rect'`` = [x, y, width, height], offset by ``box[0]`` / ``box[1]``, and
+    ``'size'`` = the rect width. Patches are enumerated column by column.
+    """
+    side = 2 * blsize
+    grid = {}
+    for center_x in range(blsize, img.shape[1] - blsize, stride):
+        for center_y in range(blsize, img.shape[0] - blsize, stride):
+            rect = [box[0] + (center_x - blsize), box[1] + (center_y - blsize), side, side]
+            grid[str(len(grid))] = {'rect': rect, 'size': rect[2]}
+    return grid
+
+# Generating local patches to perform the local refinement described in section 6 of the main paper.
+def generatepatchs(img, base_size):
+    """Select refinement patches for ``img`` and return them sorted by size, largest first.
+
+    Returns a list of ``(key, info)`` pairs, where ``info`` has the ``'rect'`` and
+    ``'size'`` entries produced by ``adaptiveselection``.
+    """
+
+    # Compute the gradients as a proxy of the contextual cues.
+    img_gray = rgb2gray(img)
+    whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\
+        np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3))
+
+    # Suppress gradients weaker than the mean of the non-zero gradients.
+    threshold = whole_grad[whole_grad > 0].mean()
+    whole_grad[whole_grad < threshold] = 0
+
+    # We use the integral image to speed-up the evaluation of the amount of gradients for each patch.
+    # gf is the mean gradient value over the whole image.
+    gf = whole_grad.sum()/len(whole_grad.reshape(-1))
+    grad_integral_image = cv2.integral(whole_grad)
+
+    # Variables are selected such that the initial patch size would be the receptive field size
+    # and the stride is set to 1/3 of the receptive field size.
+    # NOTE(review): the code below yields stride = 0.75 * blsize, i.e. about 3/8 of
+    # base_size, which does not match the "1/3" stated above -- confirm which is intended.
+    blsize = int(round(base_size/2))
+    stride = int(round(blsize*0.75))
+
+    # Get initial Grid
+    patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0])
+
+    # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine
+    # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map.
+    print("Selecting patches ...")
+    patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf)
+
+    # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest
+    # patch
+    patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True)
+    return patchset
+
+def getGF_fromintegral(integralimage, rect):
+    """Return the sum over ``rect`` = [x, y, width, height], read from an integral image.
+
+    The integral image is indexed as [row, column]; the four corner lookups are
+    combined in the usual inclusion-exclusion way.
+    """
+    left, top, width, height = rect[0], rect[1], rect[2], rect[3]
+    bottom = top + height
+    right = left + width
+    return (integralimage[bottom, right] - integralimage[top, right]
+            - integralimage[bottom, left] + integralimage[top, left])
+
+# Adaptively select patches
+def adaptiveselection(integral_grad, patch_bound_list, gf):
+    """Keep the patches whose gradient density is at least ``gf`` and grow them.
+
+    Args:
+        integral_grad: Integral image of the gradient map.
+        patch_bound_list: Dict keyed '0', '1', ... whose values contain a ``'rect'``
+            entry ([x, y, width, height]).
+        gf: Gradient density of the whole image, used as the selection threshold.
+
+    Returns:
+        dict keyed '0', '1', ... with ``'rect'`` and ``'size'`` for every kept patch.
+
+    Note:
+        Reads the module-level global ``factor``, which is assigned inside
+        ``estimateboost``; calling this function before that assignment raises
+        ``NameError``.
+    """
+    patchlist = {}
+    count = 0
+    height, width = integral_grad.shape
+
+    # Growth step per iteration; a smaller `factor` gives a larger step.
+    search_step = int(32/factor)
+
+    # Go through all patches
+    for c in range(len(patch_bound_list)):
+        # Get patch
+        bbox = patch_bound_list[str(c)]['rect']
+
+        # Compute the amount of gradients present in the patch from the integral image.
+        cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3])
+
+        # Check if patching is beneficial by comparing the gradient density of the patch to
+        # the gradient density of the whole image
+        if cgf >= gf:
+            bbox_test = bbox.copy()
+            patchlist[str(count)] = {}
+
+            # Enlarge each patch until the gradient density of the patch is equal
+            # to the whole image gradient density
+            while True:
+
+                # Grow symmetrically: move the origin back by half a step and
+                # extend width and height by a full step.
+                bbox_test[0] = bbox_test[0] - int(search_step/2)
+                bbox_test[1] = bbox_test[1] - int(search_step/2)
+
+                bbox_test[2] = bbox_test[2] + search_step
+                bbox_test[3] = bbox_test[3] + search_step
+
+                # Check if we are still within the image
+                if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \
+                        or bbox_test[0] + bbox_test[2] >= width:
+                    break
+
+                # Compare gradient density
+                cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3])
+                if cgf < gf:
+                    break
+                # The enlarged box is still dense enough: accept it and keep growing.
+                bbox = bbox_test.copy()
+
+            # Add patch to selected patches
+            patchlist[str(count)]['rect'] = bbox
+            patchlist[str(count)]['size'] = bbox[2]
+            count = count + 1
+
+    # Return selected patches
+    return patchlist
+
+def impatch(image, rect):
+    """Return the slice of ``image`` covered by ``rect`` = [x, y, width, height]."""
+    left, top = rect[0], rect[1]
+    right = left + rect[2]
+    bottom = top + rect[3]
+    return image[top:bottom, left:right]
+
+class ImageandPatchs:
+    """Holds a (rescaled) RGB image, its patch list and the current depth estimates.
+
+    ``patchsinfo`` is indexed as ``patchs[i][0]`` (patch id) and ``patchs[i][1]``
+    (dict with ``'rect'`` and ``'size'``). Indexing the object returns the image
+    patch for entry ``i`` and, once both estimates are set, the matching estimate
+    patches as well.
+    """
+
+    def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1):
+        self.root_dir = root_dir
+        self.patchsinfo = patchsinfo
+        self.name = name
+        self.patchs = patchsinfo
+        self.scale = scale
+
+        # Store the image resized by `scale` in both dimensions.
+        self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)),
+                                    interpolation=cv2.INTER_CUBIC)
+
+        # do_have_estimate becomes True once both estimates below have been set.
+        self.do_have_estimate = False
+        self.estimation_updated_image = None
+        self.estimation_base_image = None
+
+    def __len__(self):
+        """Number of patches."""
+        return len(self.patchs)
+
+    def set_base_estimate(self, est):
+        """Store the base estimate; flag estimates as available if the updated one exists."""
+        self.estimation_base_image = est
+        if self.estimation_updated_image is not None:
+            self.do_have_estimate = True
+
+    def set_updated_estimate(self, est):
+        """Store the updated estimate; flag estimates as available if the base one exists."""
+        self.estimation_updated_image = est
+        if self.estimation_base_image is not None:
+            self.do_have_estimate = True
+
+    def __getitem__(self, index):
+        """Return a dict describing patch ``index`` with rect and size scaled by ``self.scale``."""
+        patch_id = int(self.patchs[index][0])
+        rect = np.array(self.patchs[index][1]['rect'])
+        msize = self.patchs[index][1]['size']
+
+        ## applying scale to rect:
+        rect = np.round(rect * self.scale)
+        rect = rect.astype('int')
+        msize = round(msize * self.scale)
+
+        patch_rgb = impatch(self.rgb_image, rect)
+        if self.do_have_estimate:
+            patch_whole_estimate_base = impatch(self.estimation_base_image, rect)
+            patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect)
+            return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base,
+                    'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect,
+                    'size': msize, 'id': patch_id}
+        else:
+            return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id}
+
+    # NOTE(review): print_options() and parse() below read self.parser,
+    # self.gather_options and self.isTrain, none of which this class defines.
+    # They look like leftovers from an options class and would raise
+    # AttributeError if called -- confirm whether they can be removed.
+    def print_options(self, opt):
+        """Print and save options
+
+        It will print both current options and default values(if different).
+        It will save options into a text file / [checkpoints_dir] / opt.txt
+        """
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+
+        # save to the disk
+        """
+        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+        util.mkdirs(expr_dir)
+        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+        with open(file_name, 'wt') as opt_file:
+            opt_file.write(message)
+            opt_file.write('\n')
+        """
+
+    def parse(self):
+        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+        opt = self.gather_options()
+        opt.isTrain = self.isTrain   # train or test
+
+        # process opt.suffix
+        if opt.suffix:
+            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+            opt.name = opt.name + suffix
+
+        #self.print_options(opt)
+
+        # set gpu ids: keep only the non-negative ids of the comma separated list
+        str_ids = opt.gpu_ids.split(',')
+        opt.gpu_ids = []
+        for str_id in str_ids:
+            id = int(str_id)
+            if id >= 0:
+                opt.gpu_ids.append(id)
+        #if len(opt.gpu_ids) > 0:
+        #    torch.cuda.set_device(opt.gpu_ids[0])
+
+        self.opt = opt
+        return self.opt
+
+
+def estimateboost(img, model, model_type, pix2pixmodel, max_res=512):
+    """Estimate depth with a whole-image double estimation plus optional patch refinement.
+
+    Args:
+        img: Input image array; ``img.shape[0]`` / ``img.shape[1]`` are used as
+            height / width.
+        model: Depth model, passed through to ``doubleestimate``.
+        model_type: 0 selects a receptive field of 448, 1 selects 512, any other
+            value selects 384.
+        pix2pixmodel: Merge network; ``set_input``, ``test`` and
+            ``get_current_visuals`` are called on it.
+        max_res: If smaller than the computed whole-image resolution, patch
+            refinement is skipped and the whole-image estimate is returned.
+
+    Returns:
+        Depth estimate resized to the height and width of the input image.
+
+    Side effects:
+        Overwrites the module globals ``whole_size_threshold`` (when
+        ``opts.depthmap_script_boost_rmax`` exists) and ``factor``, runs
+        ``gc.collect()`` / ``devices.torch_gc()``, and prints progress messages.
+    """
+    global whole_size_threshold
+
+    # get settings
+    if hasattr(opts, 'depthmap_script_boost_rmax'):
+        whole_size_threshold = opts.depthmap_script_boost_rmax
+
+    if model_type == 0: #leres
+        net_receptive_field_size = 448
+        patch_netsize = 2 * net_receptive_field_size
+    elif model_type == 1: #dpt_beit_large_512
+        net_receptive_field_size = 512
+        patch_netsize = 2 * net_receptive_field_size
+    else: #other midas
+        net_receptive_field_size = 384
+        patch_netsize = 2 * net_receptive_field_size
+
+    gc.collect()
+    devices.torch_gc()
+
+    # Generate mask used to smoothly blend the local patch estimations to the base estimate.
+    # It is arbitrarily large to avoid artifacts during rescaling for each crop.
+    mask_org = generatemask((3000, 3000))
+    mask = mask_org.copy()
+
+    # Value x of R_x defined in the section 5 of the main paper.
+    r_threshold_value = 0.2
+    #if R0:
+    #    r_threshold_value = 0
+
+    input_resolution = img.shape
+    scale_threshold = 3  # Allows up-scaling with a scale up to 3
+
+    # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the
+    # supplementary material.
+    whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold)
+
+    # print('wholeImage being processed in :', whole_image_optimal_size)
+
+    # Generate the base estimate using the double estimation.
+    whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel)
+
+    # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select
+    # small high-density regions of the image.
+    # `factor` is a module global because adaptiveselection() reads it; it is clamped to [0.2, 1].
+    global factor
+    factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2)
+    # print('Adjust factor is:', 1/factor)
+
+    # Check if Local boosting is beneficial.
+    if max_res < whole_image_optimal_size:
+        # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result")
+        return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
+
+    # Compute the default target resolution.
+    # The longer side becomes 2 * whole_image_optimal_size, the other side keeps the
+    # aspect ratio; both are then divided by `factor`. `a` is the height, `b` the width.
+    if img.shape[0] > img.shape[1]:
+        a = 2 * whole_image_optimal_size
+        b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0])
+    else:
+        a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1])
+        b = 2 * whole_image_optimal_size
+    b = int(round(b / factor))
+    a = int(round(a / factor))
+
+    """
+    # recompute a, b and saturate to max res.
+    if max(a,b) > max_res:
+        print('Default Res is higher than max-res: Reducing final resolution')
+        if img.shape[0] > img.shape[1]:
+            a = max_res
+            b = round(max_res * img.shape[1] / img.shape[0])
+        else:
+            a = round(max_res * img.shape[0] / img.shape[1])
+            b = max_res
+        b = int(b)
+        a = int(a)
+    """
+
+    img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC)
+
+    # Extract selected patches for local refinement
+    base_size = net_receptive_field_size * 2
+    patchset = generatepatchs(img, base_size)
+
+    # print('Target resolution: ', img.shape)
+
+    # Computing a scale in case user prompted to generate the results as the same resolution of the input.
+    # Notice that our method output resolution is independent of the input resolution and this parameter will only
+    # enable a scaling operation during the local patch merge implementation to generate results with the same resolution
+    # as the input.
+    """
+    if output_resolution == 1:
+        mergein_scale = input_resolution[0] / img.shape[0]
+        print('Dynamicly change merged-in resolution; scale:', mergein_scale)
+    else:
+        mergein_scale = 1
+    """
+    # always rescale to input res for now
+    mergein_scale = input_resolution[0] / img.shape[0]
+
+    imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale)
+    whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale),
+                                        round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC)
+    # Both estimates start as copies of the whole-image estimate; the updated one is
+    # refined patch by patch in the loop below.
+    imageandpatchs.set_base_estimate(whole_estimate_resized.copy())
+    imageandpatchs.set_updated_estimate(whole_estimate_resized.copy())
+
+    print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2])
+    print('Patches to process: '+str(len(imageandpatchs)))
+
+    # Enumerate through all patches, generate their estimations and refining the base estimate.
+    for patch_ind in range(len(imageandpatchs)):
+
+        # Get patch information
+        patch = imageandpatchs[patch_ind] # patch object
+        patch_rgb = patch['patch_rgb'] # rgb patch
+        patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base
+        rect = patch['rect'] # patch size and location
+        patch_id = patch['id'] # patch ID
+        org_size = patch_whole_estimate_base.shape # the original size from the unscaled input
+        print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect)
+
+        # We apply double estimation for patches. The high resolution value is fixed to twice the receptive
+        # field size of the network for patches to accelerate the process.
+        patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel)
+        patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
+        patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC)
+
+        # Merging the patch estimation into the base estimate using our merge network:
+        # We feed the patch estimation and the same region from the updated base estimate to the merge network
+        # to generate the target estimate for the corresponding region.
+        pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation)
+
+        # Run merging network
+        pix2pixmodel.test()
+        visuals = pix2pixmodel.get_current_visuals()
+
+        prediction_mapped = visuals['fake_B']
+        prediction_mapped = (prediction_mapped+1)/2
+        prediction_mapped = prediction_mapped.squeeze().cpu().numpy()
+
+        mapped = prediction_mapped
+
+        # We use a simple linear polynomial to make sure the result of the merge network would match the values of
+        # base estimate
+        p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1)
+        merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape)
+
+        merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC)
+
+        # Get patch size and location
+        w1 = rect[0]
+        h1 = rect[1]
+        w2 = w1 + rect[2]
+        h2 = h1 + rect[3]
+
+        # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size
+        # and resize it to our needed size while merging the patches.
+        if mask.shape != org_size:
+            mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR)
+
+        tobemergedto = imageandpatchs.estimation_updated_image
+
+        # Update the whole estimation:
+        # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless
+        # blending at the boundaries of the patch region.
+        tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask)
+        imageandpatchs.set_updated_estimate(tobemergedto)
+
+    # output
+    return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC)
diff --git a/sd-webui-controlnet/annotator/leres/leres/multi_depth_model_woauxi.py b/sd-webui-controlnet/annotator/leres/leres/multi_depth_model_woauxi.py
new file mode 100644
index 0000000000000000000000000000000000000000..822ab0893267042446c2a24ed35b4ea053c9914a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/multi_depth_model_woauxi.py
@@ -0,0 +1,34 @@
+from . import network_auxi as network
+from .net_tools import get_func
+import torch
+import torch.nn as nn
+from modules import devices
+
+class RelDepthModel(nn.Module):
+    """Wrapper that builds a ``DepthModel`` for a named backbone.
+
+    Args:
+        backbone: ``'resnet50'`` or ``'resnext101'``.
+
+    Raises:
+        ValueError: If ``backbone`` is not one of the supported names. (Previously
+            an unknown name failed later with an unbound-variable error.)
+    """
+
+    def __init__(self, backbone='resnet50'):
+        super(RelDepthModel, self).__init__()
+        if backbone == 'resnet50':
+            encoder = 'resnet50_stride32'
+        elif backbone == 'resnext101':
+            encoder = 'resnext101_stride32x8d'
+        else:
+            raise ValueError("Unsupported backbone: {!r} (expected 'resnet50' or 'resnext101')".format(backbone))
+        self.depth_model = DepthModel(encoder)
+
+    def inference(self, rgb):
+        """Run the depth model on ``rgb`` without tracking gradients and return its output."""
+        with torch.no_grad():
+            # DepthModel does not define a ``device`` attribute itself, so honour one
+            # if a caller attached it and otherwise use the device of the weights.
+            device = getattr(self.depth_model, 'device', None)
+            if device is None:
+                device = next(self.depth_model.parameters()).device
+            input = rgb.to(device)
+            depth = self.depth_model(input)
+            #pred_depth_out = depth - depth.min() + 0.01
+            return depth #pred_depth_out
+
+
+class DepthModel(nn.Module):
+    """Encoder/decoder depth network.
+
+    The encoder factory is looked up by name through ``get_func`` inside the
+    ``network`` module; the decoder is ``network.Decoder``.
+    """
+
+    def __init__(self, encoder):
+        super(DepthModel, self).__init__()
+        # Last component of the network module's dotted name, joined with the factory name.
+        module_leaf = network.__name__.split('.')[-1]
+        factory_path = '.'.join((module_leaf, encoder))
+        build_encoder = get_func(factory_path)
+        self.encoder_modules = build_encoder()
+        self.decoder_modules = network.Decoder()
+
+    def forward(self, x):
+        """Feed the encoder output straight into the decoder and return the result."""
+        return self.decoder_modules(self.encoder_modules(x))
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/leres/leres/net_tools.py b/sd-webui-controlnet/annotator/leres/leres/net_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..745ba5a0ef19adb869525e6b252db86780b8126e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/net_tools.py
@@ -0,0 +1,54 @@
+import importlib
+import torch
+import os
+from collections import OrderedDict
+
+
+def get_func(func_name):
+    """Helper to return a function object by name. func_name must identify a
+    function in this module or the path to a function relative to the
+    'annotator.leres.leres' package.
+
+    Returns ``None`` for an empty name. Any lookup failure is reported on stdout
+    and the original exception is re-raised.
+    """
+    if func_name == '':
+        return None
+    try:
+        parts = func_name.split('.')
+        # Refers to a function in this module
+        if len(parts) == 1:
+            return globals()[parts[0]]
+        # Otherwise, assume we're referencing a module under annotator.leres.leres
+        module_name = 'annotator.leres.leres.' + '.'.join(parts[:-1])
+        module = importlib.import_module(module_name)
+        return getattr(module, parts[-1])
+    except Exception:
+        # Interpolate the name into the message; it used to be passed as a second
+        # print() argument, which left a literal "%s" in the output.
+        print('Failed to find function: %s' % func_name)
+        raise
+
+def load_ckpt(args, depth_model, shift_model, focal_model):
+    """
+    Load checkpoint.
+
+    Reads ``args.load_ckpt`` and loads the ``'depth_model'`` entry into
+    ``depth_model``; the ``'shift_model'`` / ``'focal_model'`` entries are loaded
+    only when the matching model argument is not ``None``. A leading ``'module.'``
+    prefix on the keys is stripped first. Nothing happens if the path is not an
+    existing file.
+    """
+    if os.path.isfile(args.load_ckpt):
+        print("loading checkpoint %s" % args.load_ckpt)
+        # Deserialize onto the CPU so that checkpoints saved from a GPU can also be
+        # read on hosts without CUDA; load_state_dict() then copies the values into
+        # the models' existing parameters.
+        checkpoint = torch.load(args.load_ckpt, map_location='cpu')
+        if shift_model is not None:
+            shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
+                                        strict=True)
+        if focal_model is not None:
+            focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
+                                        strict=True)
+        depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."),
+                                    strict=True)
+        # Release the checkpoint copy once the weights have been transferred.
+        del checkpoint
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+
+def strip_prefix_if_present(state_dict, prefix):
+    """Return ``state_dict`` with a leading ``prefix`` removed from every key.
+
+    If any key does not start with ``prefix`` the original mapping is returned
+    unchanged (same object). Otherwise a new ``OrderedDict`` is built in the
+    original key order.
+    """
+    if not all(key.startswith(prefix) for key in state_dict.keys()):
+        return state_dict
+    cut = len(prefix)
+    stripped_state_dict = OrderedDict()
+    for key, value in state_dict.items():
+        # Slice off only the leading prefix; str.replace() would also delete later
+        # occurrences of the prefix inside the key.
+        stripped_state_dict[key[cut:]] = value
+    return stripped_state_dict
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/leres/leres/network_auxi.py b/sd-webui-controlnet/annotator/leres/leres/network_auxi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd87011a5339aca632d1a10b217c8737bdc794f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/leres/network_auxi.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+
+from . import Resnet, Resnext_torch
+
+
+def resnet50_stride32():
+    """Build a DepthNet configured with the 'resnet' backbone at depth 50."""
+    settings = dict(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2])
+    return DepthNet(**settings)
+
+def resnext101_stride32x8d():
+    """Build a DepthNet configured with the 'resnext101_32x8d' backbone."""
+    settings = dict(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2])
+    return DepthNet(**settings)
+
+
+class Decoder(nn.Module):
+    """Fuses four encoder feature maps into a single-channel output.
+
+    ``forward`` takes an indexable ``features`` of four maps whose channel
+    counts are expected to be [256, 512, 1024, 2048] (``self.inchannels``).
+    The last map is reduced and upsampled, then merged with the remaining
+    maps from last to first through FFM blocks before the AO output head.
+    """
+
+    def __init__(self):
+        super(Decoder, self).__init__()
+        self.inchannels = [256, 512, 1024, 2048]
+        self.midchannels = [256, 256, 256, 512]
+        self.upfactors = [2, 2, 2, 2]
+        self.outchannels = 1
+
+        # Head applied to the last (index 3) feature map.
+        top_in = self.inchannels[3]
+        top_mid = self.midchannels[3]
+        self.conv = FTB(inchannels=top_in, midchannels=top_mid)
+        self.conv1 = nn.Conv2d(in_channels=top_mid, out_channels=self.midchannels[2],
+                               kernel_size=3, padding=1, stride=1, bias=True)
+        self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True)
+
+        # One fusion block per remaining feature map; in/out widths match midchannels.
+        self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2],
+                        outchannels=self.midchannels[2], upfactor=self.upfactors[2])
+        self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1],
+                        outchannels=self.midchannels[1], upfactor=self.upfactors[1])
+        self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0],
+                        outchannels=self.midchannels[0], upfactor=self.upfactors[0])
+
+        self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2)
+        self._init_params()
+
+    def _init_params(self):
+        """Re-initialize every conv/linear/batch-norm layer below this module."""
+        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
+        for layer in self.modules():
+            if isinstance(layer, weighted_layers):
+                # N(0, 0.01) weights and zero bias for all weighted layers.
+                init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    init.constant_(layer.bias, 0)
+            elif isinstance(layer, nn.BatchNorm2d):
+                init.constant_(layer.weight, 1)
+                init.constant_(layer.bias, 0)
+
+    def forward(self, features):
+        """Decode ``features`` (indexable, 4 maps) into the output map.
+
+        The scale notes below are carried over from the original comments and
+        presumably refer to resolution relative to the input image.
+        """
+        top = self.conv1(self.conv(features[3]))  # 1/32
+        merged = self.upsample(top)  # 1/16
+
+        merged = self.ffm2(features[2], merged)  # 1/8
+        merged = self.ffm1(features[1], merged)  # 1/4
+        merged = self.ffm0(features[0], merged)  # 1/2
+        return self.outconv(merged)  # original size
+
+class DepthNet(nn.Module):
+    """Encoder wrapper that selects a ResNet / ResNeXt backbone.
+
+    Parameters:
+        backbone (str)   -- 'resnet', 'resnext101_32x8d'; any other value
+                            falls back to ``Resnext_torch.resnext101``
+        depth (int)      -- ResNet depth (18 | 34 | 50 | 101 | 152); only
+                            consulted when backbone == 'resnet'
+        upfactors (list) -- stored on the instance; defaults to [2, 2, 2, 2]
+
+    Raises:
+        KeyError -- backbone is 'resnet' and ``depth`` is not supported.
+    """
+    __factory = {
+        18: Resnet.resnet18,
+        34: Resnet.resnet34,
+        50: Resnet.resnet50,
+        101: Resnet.resnet101,
+        152: Resnet.resnet152
+    }
+
+    def __init__(self,
+                 backbone='resnet',
+                 depth=50,
+                 upfactors=None):
+        super(DepthNet, self).__init__()
+        self.backbone = backbone
+        self.depth = depth
+        self.pretrained = False
+        self.inchannels = [256, 512, 1024, 2048]
+        self.midchannels = [256, 256, 256, 512]
+        # A list literal as the default would be one object shared by every
+        # instance (it is stored on self), so build a fresh list per instance.
+        self.upfactors = [2, 2, 2, 2] if upfactors is None else upfactors
+        self.outchannels = 1
+
+        # Build model
+        if self.backbone == 'resnet':
+            if self.depth not in DepthNet.__factory:
+                raise KeyError("Unsupported depth:", self.depth)
+            self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained)
+        elif self.backbone == 'resnext101_32x8d':
+            self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained)
+        else:
+            # Any unrecognized backbone name ends up here.
+            self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained)
+
+    def forward(self, x):
+        """Run the backbone and return its output unchanged."""
+        x = self.encoder(x)  # 1/32, 1/16, 1/8, 1/4 (presumably multi-scale features -- confirm in Resnet/Resnext_torch)
+        return x
+
+
+class FTB(nn.Module):
+    """3x3 conv followed by a two-conv residual branch and a final ReLU.
+
+    Parameters:
+        inchannels (int)  -- channels of the input map
+        midchannels (int) -- channels produced by every conv in the block
+    """
+
+    def __init__(self, inchannels, midchannels=512):
+        super(FTB, self).__init__()
+        self.in1 = inchannels
+        self.mid = midchannels
+        self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid,
+                               kernel_size=3, padding=1, stride=1, bias=True)
+        # ReLU -> conv -> BN -> ReLU -> conv; the Sequential indices (1, 2, 4)
+        # are part of the state-dict key names, so the layout must not change.
+        branch_layers = [
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
+                      padding=1, stride=1, bias=True),
+            nn.BatchNorm2d(num_features=self.mid),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3,
+                      padding=1, stride=1, bias=True),
+        ]
+        self.conv_branch = nn.Sequential(*branch_layers)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.init_params()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        # The branch starts with an in-place ReLU, so `x` itself is rectified
+        # before the addition: the skip path carries relu(conv1(x)).
+        x = x + self.conv_branch(x)
+        return self.relu(x)
+
+    def init_params(self):
+        """N(0, 0.01) weights / zero bias for conv and linear layers; BN to (1, 0)."""
+        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
+        for layer in self.modules():
+            if isinstance(layer, weighted_layers):
+                init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    init.constant_(layer.bias, 0)
+            elif isinstance(layer, nn.BatchNorm2d):
+                init.constant_(layer.weight, 1)
+                init.constant_(layer.bias, 0)
+
+
+class ATA(nn.Module):
+    """Channel gate computed from two feature maps.
+
+    ``low_x`` and ``high_x`` are concatenated on dim 1, globally average
+    pooled and pushed through a two-layer MLP ending in a sigmoid; the result
+    scales ``low_x`` per channel before ``high_x`` is added.
+
+    Parameters:
+        inchannels (int) -- channels of each input map (the MLP input is 2x this)
+        reduction (int)  -- divisor for the hidden width of the MLP
+    """
+
+    def __init__(self, inchannels, reduction=8):
+        super(ATA, self).__init__()
+        self.inchannels = inchannels
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        hidden = self.inchannels // reduction
+        self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, hidden),
+                                nn.ReLU(inplace=True),
+                                nn.Linear(hidden, self.inchannels),
+                                nn.Sigmoid())
+        self.init_params()
+
+    def forward(self, low_x, high_x):
+        n, c, _, _ = low_x.size()
+        pooled = self.avg_pool(torch.cat([low_x, high_x], 1)).view(n, -1)
+        gate = self.fc(pooled).view(n, c, 1, 1)
+        return low_x * gate + high_x
+
+    def init_params(self):
+        """Xavier-normal convs, N(0, 0.01) linears, zero biases, BN to (1, 0)."""
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
+                init.xavier_normal_(layer.weight)
+                if layer.bias is not None:
+                    init.constant_(layer.bias, 0)
+            elif isinstance(layer, nn.BatchNorm2d):
+                init.constant_(layer.weight, 1)
+                init.constant_(layer.bias, 0)
+            elif isinstance(layer, nn.Linear):
+                init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    init.constant_(layer.bias, 0)
+
+
+class FFM(nn.Module):
+    """Feature fusion: transform ``low_x`` with an FTB, add ``high_x``, run a
+    second FTB and upsample bilinearly by ``upfactor``.
+
+    Parameters:
+        inchannels (int)  -- channels of ``low_x``
+        midchannels (int) -- channels after the first FTB (must match ``high_x``)
+        outchannels (int) -- channels after the second FTB
+        upfactor (int)    -- scale factor of the final upsampling
+    """
+
+    def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
+        super(FFM, self).__init__()
+        self.inchannels = inchannels
+        self.midchannels = midchannels
+        self.outchannels = outchannels
+        self.upfactor = upfactor
+
+        self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
+        # self.ata = ATA(inchannels = self.midchannels)
+        self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
+
+        self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
+
+        self.init_params()
+
+    def forward(self, low_x, high_x):
+        fused = self.ftb1(low_x) + high_x
+        return self.upsample(self.ftb2(fused))
+
+    def init_params(self):
+        """N(0, 0.01) weights / zero bias for conv and linear layers; BN to (1, 0)."""
+        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
+        for layer in self.modules():
+            if isinstance(layer, weighted_layers):
+                init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    init.constant_(layer.bias, 0)
+            elif isinstance(layer, nn.BatchNorm2d):
+                init.constant_(layer.weight, 1)
+                init.constant_(layer.bias, 0)
+
+
+class AO(nn.Module):
+    """Adaptive output module: conv -> BN -> ReLU -> conv -> bilinear upsample.
+
+    Parameters:
+        inchannels (int)  -- channels of the input map (halved by the first conv)
+        outchannels (int) -- channels of the produced map
+        upfactor (int)    -- scale factor of the final upsampling
+    """
+
+    def __init__(self, inchannels, outchannels, upfactor=2):
+        super(AO, self).__init__()
+        self.inchannels = inchannels
+        self.outchannels = outchannels
+        self.upfactor = upfactor
+
+        half = self.inchannels // 2
+        stages = [
+            nn.Conv2d(in_channels=self.inchannels, out_channels=half, kernel_size=3, padding=1,
+                      stride=1, bias=True),
+            nn.BatchNorm2d(num_features=half),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels=half, out_channels=self.outchannels, kernel_size=3, padding=1,
+                      stride=1, bias=True),
+            nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True),
+        ]
+        self.adapt_conv = nn.Sequential(*stages)
+
+        self.init_params()
+
+    def forward(self, x):
+        return self.adapt_conv(x)
+
+    def init_params(self):
+        """N(0, 0.01) weights / zero bias for conv and linear layers; BN to (1, 0)."""
+        weighted_layers = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
+        for layer in self.modules():
+            if isinstance(layer, weighted_layers):
+                init.normal_(layer.weight, std=0.01)
+                if layer.bias is not None:
+                    init.constant_(layer.bias, 0)
+            elif isinstance(layer, nn.BatchNorm2d):
+                init.constant_(layer.weight, 1)
+                init.constant_(layer.bias, 0)
+
+
+
+# ==============================================================================================================
+
+
+class ResidualConv(nn.Module):
+    """Bottleneck residual unit: x + conv(relu(bn(conv(relu(x))))).
+
+    Parameters:
+        inchannels (int) -- channels of the input and output map; the inner
+                            width is ``inchannels // 2``
+    """
+
+    def __init__(self, inchannels):
+        super(ResidualConv, self).__init__()
+        # Floor division: under Python 3 `inchannels / 2` is a float, which
+        # nn.Conv2d / nn.BatchNorm2d reject as a channel count.
+        half = inchannels // 2
+        # NN.BatchNorm2d
+        self.conv = nn.Sequential(
+            # nn.BatchNorm2d(num_features=inchannels),
+            nn.ReLU(inplace=False),
+            # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
+            # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
+            nn.Conv2d(in_channels=inchannels, out_channels=half, kernel_size=3, padding=1, stride=1,
+                      bias=False),
+            nn.BatchNorm2d(num_features=half),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(in_channels=half, out_channels=inchannels, kernel_size=3, padding=1, stride=1,
+                      bias=False)
+        )
+        self.init_params()
+
+    def forward(self, x):
+        # The ReLUs are not in-place, so the skip path keeps the raw input.
+        x = self.conv(x) + x
+        return x
+
+    def init_params(self):
+        """N(0, 0.01) weights / zero bias for conv and linear layers; BN to (1, 0)."""
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
+                init.normal_(m.weight, std=0.01)
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):  # NN.BatchNorm2d
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+
+
+class FeatureFusion(nn.Module):
+    """Adds a refined low-level map to a high-level map and upsamples by 2.
+
+    ``forward`` computes ``up(highfeat + conv(lowfeat))`` where ``conv`` is a
+    ResidualConv and ``up`` is ResidualConv -> stride-2 ConvTranspose2d -> BN
+    -> ReLU.
+
+    Parameters:
+        inchannels (int)  -- channels of both input maps
+        outchannels (int) -- channels of the upsampled output
+    """
+
+    def __init__(self, inchannels, outchannels):
+        super(FeatureFusion, self).__init__()
+        self.conv = ResidualConv(inchannels=inchannels)
+        # NN.BatchNorm2d
+        self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
+                                nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,
+                                                   stride=2, padding=1, output_padding=1),
+                                nn.BatchNorm2d(num_features=outchannels),
+                                nn.ReLU(inplace=True))
+        # Every sibling block in this file applies its init_params() at
+        # construction; this one defined the method but never invoked it, so
+        # the transposed conv kept PyTorch's default initialization.
+        self.init_params()
+
+    def forward(self, lowfeat, highfeat):
+        return self.up(highfeat + self.conv(lowfeat))
+
+    def init_params(self):
+        """N(0, 0.01) weights / zero bias for conv and linear layers; BN to (1, 0)."""
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
+                init.normal_(m.weight, std=0.01)
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):  # NN.BatchNorm2d
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+
+
+class SenceUnderstand(nn.Module):
+    """Global context vector broadcast back over the spatial grid.
+
+    The input (512 channels, fixed by ``conv1``) is convolved, pooled to 8x8,
+    flattened into a ``channels``-wide vector by a linear layer, passed
+    through a 1x1 conv and finally tiled to the input's height and width.
+
+    Parameters:
+        channels (int) -- width of the context vector / output channels
+    """
+
+    def __init__(self, channels):
+        super(SenceUnderstand, self).__init__()
+        self.channels = channels
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True))
+        self.pool = nn.AdaptiveAvgPool2d(8)
+        self.fc = nn.Sequential(
+            nn.Linear(512 * 8 * 8, self.channels),
+            nn.ReLU(inplace=True))
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
+            nn.ReLU(inplace=True))
+        self.initial_params()
+
+    def forward(self, x):
+        n, c, h, w = x.size()
+        pooled = self.pool(self.conv1(x))
+        context = self.fc(pooled.view(n, -1)).view(n, self.channels, 1, 1)
+        context = self.conv2(context)
+        # Tile the 1x1 context map to the spatial size of the input.
+        return context.repeat(1, 1, h, w)
+
+    def initial_params(self, dev=0.01):
+        """Draw weights from N(0, dev); conv biases are zeroed, linear biases untouched."""
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
+                layer.weight.data.normal_(0, dev)
+                if layer.bias is not None:
+                    layer.bias.data.fill_(0)
+            elif isinstance(layer, nn.Linear):
+                layer.weight.data.normal_(0, dev)
+
+
+if __name__ == '__main__':
+    # Smoke test: build the ResNet-50 encoder and print what it produces.
+    # DepthNet.__init__ accepts only backbone/depth/upfactors, so no
+    # `pretrained` keyword is passed (it used to raise TypeError here).
+    net = DepthNet(depth=50)
+    print(net)
+    inputs = torch.ones(4, 3, 128, 128)
+    out = net(inputs)
+    # The encoder presumably returns several feature maps (see the comment in
+    # DepthNet.forward); cope with a single tensor as well.
+    feature_maps = [out] if torch.is_tensor(out) else out
+    for feature_map in feature_maps:
+        print(feature_map.size())
+
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/LICENSE b/sd-webui-controlnet/annotator/leres/pix2pix/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..38b1a24fd389a138b930dcf1ee606ef97a0186c8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/LICENSE
@@ -0,0 +1,19 @@
+https://github.com/compphoto/BoostingMonocularDepth
+
+Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved.
+
+This software is for academic use only. A redistribution of this
+software, with or without modifications, has to be for academic
+use only, while giving the appropriate credit to the original
+authors of the software. The methods implemented as a part of
+this software may be covered under patents or patent applications.
+
+THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED
+WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR
+CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/models/__init__.py b/sd-webui-controlnet/annotator/leres/pix2pix/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f96e5c7f032f2154c6bb433b68fc968d0a19b5a8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from .base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import "annotator.leres.pix2pix.models.[model_name]_model" and return its model class.
+
+    The class is the attribute of that module whose name equals
+    ``model_name`` without underscores plus 'model', compared
+    case-insensitively, and which is a subclass of BaseModel.
+
+    Parameters:
+        model_name (str) -- e.g. 'dummy' for module 'dummy_model' / class 'DummyModel'
+
+    Returns:
+        The matching class (the last one found if several match).
+
+    Raises:
+        ImportError -- the module cannot be imported.
+        SystemExit  -- no matching class exists; a message is printed first.
+    """
+    model_filename = "annotator.leres.pix2pix.models." + model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        # issubclass() raises TypeError for non-classes, so make sure the
+        # name-matched attribute really is a class before asking.
+        if name.lower() == target_model_name.lower() \
+                and isinstance(cls, type) and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+        # Same exception type as the former exit(0), but with a failure status
+        # and without relying on the site-provided `exit` helper.
+        raise SystemExit(1)
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Resolve the model class for ``model_name`` and hand back its
+    ``modify_commandline_options`` static method."""
+    return find_model_using_name(model_name).modify_commandline_options
+
+
+def create_model(opt):
+    """Instantiate the model class selected by ``opt.model``.
+
+    The class is resolved with find_model_using_name(), constructed with
+    ``opt`` and returned after its class name has been printed.
+
+    Example:
+        >>> from models import create_model
+        >>> model = create_model(opt)
+    """
+    model_cls = find_model_using_name(opt.model)
+    instance = model_cls(opt)
+    created_name = type(instance).__name__
+    print("model [%s] was created" % created_name)
+    return instance
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model.py b/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90c5f832404bc44ef247b42a72988a37fc834cb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model.py
@@ -0,0 +1,241 @@
+import os
+import torch, gc
+from modules import devices
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>:                     unpack data from dataset and apply preprocessing.
+        -- <forward>:                       produce intermediate results.
+        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
+
+    Networks are looked up by convention: for every string ``name`` in
+    ``self.model_names`` the network is the attribute ``self.net<name>``.
+    """
+
+    def __init__(self, opt):
+        """Initialize the BaseModel class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        When creating your custom class, you need to implement your own initialization.
+        In this function, you should first call <BaseModel.__init__(self, opt)>
+        Then, you need to define four lists:
+            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
+            -- self.model_names (str list):         define networks used in our training.
+            -- self.visual_names (str list):        specify the images that you want to display and save.
+            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+        """
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.isTrain = opt.isTrain
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+            # NOTE: this flips a process-wide PyTorch setting.
+            torch.backends.cudnn.benchmark = True
+        self.loss_names = []
+        self.model_names = []
+        self.visual_names = []
+        self.optimizers = []
+        self.image_paths = []
+        self.metric = 0  # used for learning rate policy 'plateau'
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new model-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+        """
+        pass
+
+    @abstractmethod
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        pass
+
+    @abstractmethod
+    def optimize_parameters(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        pass
+
+    def setup(self, opt):
+        """Load and print networks; create schedulers
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        if self.isTrain:
+            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+        if not self.isTrain or opt.continue_train:
+            # A positive load_iter selects an iteration checkpoint, otherwise the epoch label is used.
+            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+            self.load_networks(load_suffix)
+        self.print_networks(opt.verbose)
+
+    def eval(self):
+        """Make models eval mode during test time"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                net.eval()
+
+    def test(self):
+        """Forward function used in test time.
+
+        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
+        It also calls <compute_visuals> to produce additional visualization results
+        """
+        with torch.no_grad():
+            self.forward()
+            self.compute_visuals()
+
+    def compute_visuals(self):
+        """Calculate additional output images for visdom and HTML visualization"""
+        pass
+
+    def get_image_paths(self):
+        """ Return image paths that are used to load current data"""
+        return self.image_paths
+
+    def update_learning_rate(self):
+        """Update learning rates for all the networks; called at the end of every epoch"""
+        old_lr = self.optimizers[0].param_groups[0]['lr']
+        for scheduler in self.schedulers:
+            if self.opt.lr_policy == 'plateau':
+                scheduler.step(self.metric)
+            else:
+                scheduler.step()
+
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate %.7f -> %.7f' % (old_lr, lr))
+
+    def get_current_visuals(self):
+        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+        visual_ret = OrderedDict()
+        for name in self.visual_names:
+            if isinstance(name, str):
+                visual_ret[name] = getattr(self, name)
+        return visual_ret
+
+    def get_current_losses(self):
+        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
+        errors_ret = OrderedDict()
+        for name in self.loss_names:
+            if isinstance(name, str):
+                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+        return errors_ret
+
+    def save_networks(self, epoch):
+        """Save all the networks to the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                save_filename = '%s_net_%s.pth' % (epoch, name)
+                save_path = os.path.join(self.save_dir, save_filename)
+                net = getattr(self, 'net' + name)
+
+                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                    # Unwrap DataParallel only when the network is actually
+                    # wrapped (same check as load_networks); a bare module has
+                    # no `.module` attribute.
+                    bare_net = net.module if isinstance(net, torch.nn.DataParallel) else net
+                    torch.save(bare_net.cpu().state_dict(), save_path)
+                    # .cpu() moved the weights; put them back on the first GPU.
+                    net.cuda(self.gpu_ids[0])
+                else:
+                    torch.save(net.cpu().state_dict(), save_path)
+
+    def unload_network(self, name):
+        """Unload network and gc.
+
+        NOTE(review): `del net` only drops the local name; ``self.net<name>``
+        still references the network afterwards, so this effectively just
+        triggers garbage collection and devices.torch_gc(). Left as is
+        because callers may rely on the attribute remaining usable.
+        """
+        if isinstance(name, str):
+            net = getattr(self, 'net' + name)
+            del net
+            gc.collect()
+            devices.torch_gc()
+            return None
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)
+
+        Walks ``module`` along the dotted key path ``keys`` and removes the
+        entry from ``state_dict`` when it is an InstanceNorm running statistic
+        the module does not have, or an InstanceNorm 'num_batches_tracked'.
+        """
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+    def load_networks(self, epoch):
+        """Load all the networks from the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        for name in self.model_names:
+            if isinstance(name, str):
+                load_filename = '%s_net_%s.pth' % (epoch, name)
+                load_path = os.path.join(self.save_dir, load_filename)
+                net = getattr(self, 'net' + name)
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                # print('Loading depth boost model from %s' % load_path)
+                # if you are using PyTorch newer than 0.4 (e.g., built from
+                # GitHub source), you can remove str() on self.device
+                state_dict = torch.load(load_path, map_location=str(self.device))
+                if hasattr(state_dict, '_metadata'):
+                    del state_dict._metadata
+
+                # patch InstanceNorm checkpoints prior to 0.4
+                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                net.load_state_dict(state_dict)
+
+    def print_networks(self, verbose):
+        """Print the total number of parameters in the network and (if verbose) network architecture
+
+        Parameters:
+            verbose (bool) -- if verbose: print the network architecture
+        """
+        print('---------- Networks initialized -------------')
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, 'net' + name)
+                num_params = 0
+                for param in net.parameters():
+                    num_params += param.numel()
+                if verbose:
+                    print(net)
+                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+        print('-----------------------------------------------')
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requires_grad=False for all the networks to avoid unnecessary computations
+        Parameters:
+            nets (network list)   -- a list of networks (a single network is also accepted)
+            requires_grad (bool)  -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model_hg.py b/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model_hg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1709accdf0b048b3793dfd1f58d1b06c35f7b907
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/models/base_model_hg.py
@@ -0,0 +1,58 @@
+import os
+import torch
+
+class BaseModelHG():
+    """Minimal model base class; most hooks are no-ops for subclasses to override."""
+
+    def name(self):
+        """Return the model's display name."""
+        return 'BaseModel'
+
+    def initialize(self, opt):
+        """Store the options and derive the tensor type and checkpoint directory.
+
+        Parameters:
+            opt -- object exposing gpu_ids, isTrain, checkpoints_dir and name
+        """
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.isTrain = opt.isTrain
+        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+    def set_input(self, input):
+        self.input = input
+
+    def forward(self):
+        pass
+
+    # used in test time, no backprop
+    def test(self):
+        pass
+
+    def get_image_paths(self):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        return self.input
+
+    def get_current_errors(self):
+        return {}
+
+    def save(self, label):
+        pass
+
+    # helper saving function that can be used by subclasses
+    def save_network(self, network, network_label, epoch_label, gpu_ids):
+        """Write ``network``'s state dict to '_<epoch_label>_net_<network_label>.pth' in save_dir."""
+        save_filename = '_%s_net_%s.pth' % (epoch_label, network_label)
+        save_path = os.path.join(self.save_dir, save_filename)
+        torch.save(network.cpu().state_dict(), save_path)
+        if len(gpu_ids) and torch.cuda.is_available():
+            # Module.cuda() takes the device positionally / as `device`; the
+            # old `device_id` keyword is rejected by current PyTorch.
+            network.cuda(gpu_ids[0])
+
+    # helper loading function that can be used by subclasses
+    def load_network(self, network, network_label, epoch_label):
+        """Load and return the object stored at '<epoch_label>_net_<network_label>.pth'.
+
+        ``network`` is not modified; the caller applies the returned object.
+        """
+        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+        save_path = os.path.join(self.save_dir, save_filename)
+        print(save_path)
+        model = torch.load(save_path)
+        return model
+        # network.load_state_dict(torch.load(save_path))
+
+    def update_learning_rate(self):
+        # `self` was missing, so calling this on an instance raised TypeError.
+        pass
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/models/networks.py b/sd-webui-controlnet/annotator/leres/pix2pix/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cf912b2973721a02deefd042af621e732bad59f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/models/networks.py
@@ -0,0 +1,623 @@
+import torch
+import torch.nn as nn
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+
+
+###############################################################################
+# Helper Functions
+###############################################################################
+
+
+class Identity(nn.Module):
+    """Pass-through module: forward hands its argument back untouched."""
+
+    def forward(self, x):
+        # Nothing is computed; the very same object is returned.
+        return x
+
+
+def get_norm_layer(norm_type='instance'):
+    """Build a factory for the requested normalization layer.
+
+    Parameters:
+        norm_type (str) -- one of: batch | instance | none
+
+    Returns a callable that takes the number of features and returns a layer:
+        batch    -- nn.BatchNorm2d with learnable affine parameters and running statistics
+        instance -- nn.InstanceNorm2d without affine parameters and without running statistics
+        none     -- an Identity module (the argument is ignored)
+
+    Raises NotImplementedError for any other norm_type.
+    """
+    if norm_type == 'batch':
+        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
+
+    if norm_type == 'instance':
+        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+
+    if norm_type == 'none':
+        def norm_layer(x):
+            return Identity()
+        return norm_layer
+
+    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+
+
+def get_scheduler(optimizer, opt):
+    """Return a learning rate scheduler
+
+    Parameters:
+        optimizer -- the optimizer of the network
+        opt       -- option object; opt.lr_policy selects the policy: linear | step | plateau | cosine
+
+    Attributes of opt read per policy:
+        linear  -- epoch_count, n_epochs, n_epochs_decay
+        step    -- lr_decay_iters
+        plateau -- none
+        cosine  -- n_epochs
+
+    For 'linear', the learning-rate factor stays at 1.0 while epoch + opt.epoch_count <= opt.n_epochs
+    and then decreases linearly, in steps of 1 / (opt.n_epochs_decay + 1) per epoch.
+    For the other policies (step, plateau, and cosine), the stock PyTorch schedulers are used.
+    See https://pytorch.org/docs/stable/optim.html for more details.
+
+    Raises NotImplementedError for an unknown opt.lr_policy.
+    """
+    if opt.lr_policy == 'linear':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    elif opt.lr_policy == 'cosine':
+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+    else:
+        # This used to *return* the exception object (with an unformatted message),
+        # so callers only failed later when they tried to step the "scheduler".
+        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+    return scheduler
+
+
+def init_weights(net, init_type='normal', init_gain=0.02):
+    """Initialize network weights in place.
+
+    Parameters:
+        net (network)     -- network to be initialized
+        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+    Modules whose class name contains 'Conv' or 'Linear' get their weight initialized with the
+    chosen method and their bias (if any) set to zero. Modules whose class name contains
+    'BatchNorm2d' get weight ~ N(1, init_gain) and zero bias, when those parameters exist.
+
+    Raises NotImplementedError for an unknown init_type (once a Conv/Linear module is visited).
+    """
+    def init_func(m):  # define the initialization function
+        classname = m.__class__.__name__
+        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+            if init_type == 'normal':
+                init.normal_(m.weight.data, 0.0, init_gain)
+            elif init_type == 'xavier':
+                init.xavier_normal_(m.weight.data, gain=init_gain)
+            elif init_type == 'kaiming':
+                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+            elif init_type == 'orthogonal':
+                init.orthogonal_(m.weight.data, gain=init_gain)
+            else:
+                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+            if hasattr(m, 'bias') and m.bias is not None:
+                init.constant_(m.bias.data, 0.0)
+        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+            # With affine=False the layer has weight and bias set to None; skip them
+            # instead of failing on `None.data`.
+            if getattr(m, 'weight', None) is not None:
+                init.normal_(m.weight.data, 1.0, init_gain)
+            if getattr(m, 'bias', None) is not None:
+                init.constant_(m.bias.data, 0.0)
+
+    # print('initialize network with %s' % init_type)
+    net.apply(init_func)  # apply the initialization function
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
+    """Place a network on its device(s) and initialize its weights.
+
+    Parameters:
+        net (network)      -- the network to be initialized
+        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
+        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+    When gpu_ids is non-empty, CUDA must be available (asserted); the network is moved to
+    gpu_ids[0] and wrapped in torch.nn.DataParallel over gpu_ids before initialization.
+
+    Return an initialized network.
+    """
+    wants_gpu = len(gpu_ids) > 0
+    if wants_gpu:
+        assert(torch.cuda.is_available())
+        net.to(gpu_ids[0])
+        # multi-GPU wrapper
+        net = torch.nn.DataParallel(net, gpu_ids)
+
+    init_weights(net, init_type, init_gain=init_gain)
+    return net
+
+
+def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
+    """Create a generator
+
+    Parameters:
+        input_nc (int)     -- the number of channels in input images
+        output_nc (int)    -- the number of channels in output images
+        ngf (int)          -- the number of filters in the last conv layer
+        netG (str)         -- the architecture's name:
+                              resnet_6blocks | resnet_9blocks | resnet_12blocks |
+                              unet_128 | unet_256 | unet_672 | unet_960 | unet_1024
+        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
+        use_dropout (bool) -- if use dropout layers.
+        init_type (str)    -- the name of our initialization method.
+        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
+        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+    Returns a generator
+
+    Two families of generators are available:
+        U-Net: the number after 'unet_' selects a fixed number of downsamplings
+        (128 -> 7, 256 -> 8, 672 -> 5, 960 -> 6, 1024 -> 10).
+        The original U-Net paper: https://arxiv.org/abs/1505.04597
+
+        Resnet-based generator: 6, 9 or 12 Resnet blocks between a few downsampling/upsampling operations.
+        Adapted from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
+
+    The generator is passed through init_net before it is returned.
+    Raises NotImplementedError for an unknown netG.
+    """
+    norm_layer = get_norm_layer(norm_type=norm)
+
+    # architecture name -> number of Resnet blocks
+    resnet_depths = (('resnet_9blocks', 9), ('resnet_6blocks', 6), ('resnet_12blocks', 12))
+    # architecture name -> number of U-Net downsamplings
+    unet_depths = (('unet_128', 7), ('unet_256', 8), ('unet_672', 5), ('unet_960', 6), ('unet_1024', 10))
+
+    net = None
+    for arch_name, n_blocks in resnet_depths:
+        if netG == arch_name:
+            net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=n_blocks)
+            break
+    else:
+        # not a Resnet variant: try the U-Net variants
+        for arch_name, num_downs in unet_depths:
+            if netG == arch_name:
+                net = UnetGenerator(input_nc, output_nc, num_downs, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+                break
+        else:
+            raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
+    return init_net(net, init_type, init_gain, gpu_ids)
+
+
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
+    """Create a discriminator
+
+    Parameters:
+        input_nc (int)     -- the number of channels in input images
+        ndf (int)          -- the number of filters in the first conv layer
+        netD (str)         -- the architecture's name: basic | n_layers | pixel
+        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
+        norm (str)         -- the type of normalization layers used in the network.
+        init_type (str)    -- the name of the initialization method.
+        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
+        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+    Returns a discriminator
+
+    Three types of discriminators are available:
+        [basic]: 'PatchGAN' classifier described in the original pix2pix paper
+        (an NLayerDiscriminator with n_layers=3). It is fully convolutional, so it
+        works on arbitrarily-sized images.
+
+        [n_layers]: an NLayerDiscriminator whose number of conv layers is given by n_layers_D.
+
+        [pixel]: 1x1 PixelGAN discriminator that classifies each pixel as real or not.
+
+    The discriminator is passed through init_net before it is returned.
+    Raises NotImplementedError for an unknown netD.
+    """
+    norm_layer = get_norm_layer(norm_type=norm)
+
+    if netD == 'pixel':
+        # per-pixel real/fake classifier
+        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+    elif netD == 'n_layers':
+        # PatchGAN with a caller-chosen depth
+        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+    elif netD == 'basic':
+        # default PatchGAN classifier
+        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+    else:
+        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
+
+    return init_net(net, init_type, init_gain, gpu_ids)
+
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+    """GAN objectives behind one call interface.
+
+    The class builds the target label tensor itself, so callers only say
+    whether the prediction should be judged against "real" or "fake".
+    """
+
+    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+        """Set up the loss for the chosen objective.
+
+        Parameters:
+            gan_mode (str)            -- the type of GAN objective: vanilla | lsgan | wgangp
+            target_real_label (float) -- label value used for real images
+            target_fake_label (float) -- label value used for fake images
+
+        lsgan uses nn.MSELoss, vanilla uses nn.BCEWithLogitsLoss (so the discriminator
+        must not end in a sigmoid), and wgangp needs no loss module.
+        Raises NotImplementedError for any other gan_mode.
+        """
+        super(GANLoss, self).__init__()
+        # Buffers, so the labels follow the module across devices.
+        self.register_buffer('real_label', torch.tensor(target_real_label))
+        self.register_buffer('fake_label', torch.tensor(target_fake_label))
+        self.gan_mode = gan_mode
+
+        if gan_mode not in ('lsgan', 'vanilla', 'wgangp'):
+            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+        if gan_mode == 'lsgan':
+            self.loss = nn.MSELoss()
+        elif gan_mode == 'vanilla':
+            self.loss = nn.BCEWithLogitsLoss()
+        else:
+            # wgangp: the loss is computed directly from the prediction
+            self.loss = None
+
+    def get_target_tensor(self, prediction, target_is_real):
+        """Return the label, expanded to the shape of the prediction.
+
+        Parameters:
+            prediction (tensor)   -- typically the prediction from a discriminator
+            target_is_real (bool) -- whether the ground truth label is "real" or "fake"
+        """
+        label = self.real_label if target_is_real else self.fake_label
+        return label.expand_as(prediction)
+
+    def __call__(self, prediction, target_is_real):
+        """Compute the loss for a discriminator output.
+
+        Parameters:
+            prediction (tensor)   -- typically the prediction output from a discriminator
+            target_is_real (bool) -- whether the ground truth label is "real" or "fake"
+
+        Returns:
+            the calculated loss; for wgangp this is -mean(prediction) for real
+            targets and +mean(prediction) for fake targets.
+        """
+        if self.gan_mode == 'wgangp':
+            mean_score = prediction.mean()
+            loss = -mean_score if target_is_real else mean_score
+        elif self.gan_mode in ['lsgan', 'vanilla']:
+            labels = self.get_target_tensor(prediction, target_is_real)
+            loss = self.loss(prediction, labels)
+        return loss
+
+
+def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
+    """Gradient penalty loss from the WGAN-GP paper https://arxiv.org/abs/1704.00028
+
+    Arguments:
+        netD (network)     -- discriminator network
+        real_data (tensor) -- real images
+        fake_data (tensor) -- generated images from the generator
+        device             -- torch device on which alpha and the grad_outputs are created
+        type (str)         -- which points the penalty is evaluated at [real | fake | mixed];
+                              'mixed' uses a random per-sample interpolation of real and fake
+        constant (float)   -- the constant used in formula ( ||gradient||_2 - constant)^2
+        lambda_gp (float)  -- weight for this loss
+
+    Returns (gradient_penalty, gradients), with gradients flattened to one row per sample,
+    or (0.0, None) when lambda_gp is not greater than zero.
+    Raises NotImplementedError for an unknown type.
+
+    Note: for 'real' / 'fake' the given tensor itself gets requires_grad switched on.
+    """
+    if not lambda_gp > 0.0:
+        return 0.0, None
+
+    if type == 'real':
+        interpolatesv = real_data
+    elif type == 'fake':
+        interpolatesv = fake_data
+    elif type == 'mixed':
+        batch = real_data.shape[0]
+        per_sample = real_data.nelement() // batch
+        # one random weight per sample, broadcast over all of that sample's elements
+        alpha = torch.rand(batch, 1, device=device)
+        alpha = alpha.expand(batch, per_sample).contiguous().view(*real_data.shape)
+        interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+    else:
+        raise NotImplementedError('{} not implemented'.format(type))
+
+    interpolatesv.requires_grad_(True)
+    disc_interpolates = netD(interpolatesv)
+    ones = torch.ones(disc_interpolates.size()).to(device)
+    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
+                                    grad_outputs=ones,
+                                    create_graph=True, retain_graph=True, only_inputs=True)
+    # flatten to (batch, -1)
+    gradients = gradients[0].view(real_data.size(0), -1)
+    # the small epsilon is added before taking the norm
+    grad_norm = (gradients + 1e-16).norm(2, dim=1)
+    gradient_penalty = ((grad_norm - constant) ** 2).mean() * lambda_gp
+    return gradient_penalty, gradients
+
+
+class ResnetGenerator(nn.Module):
+    """Generator made of Resnet blocks placed between downsampling and upsampling convolutions.
+
+    Adapted from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)
+    """
+
+    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
+        """Construct a Resnet-based generator
+
+        Parameters:
+            input_nc (int)     -- the number of channels in input images
+            output_nc (int)    -- the number of channels in output images
+            ngf (int)          -- the number of filters in the last conv layer
+            norm_layer         -- normalization layer
+            use_dropout (bool) -- if use dropout layers
+            n_blocks (int)     -- the number of ResNet blocks (must be >= 0)
+            padding_type (str) -- the name of padding layer in the ResNet blocks: reflect | replicate | zero
+        """
+        assert(n_blocks >= 0)
+        super(ResnetGenerator, self).__init__()
+        # Convolutions carry a bias only when the norm layer is InstanceNorm2d.
+        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
+        use_bias = norm_cls == nn.InstanceNorm2d
+
+        layers = []
+
+        # 7x7 stem, reflection-padded so the spatial size is kept
+        layers.append(nn.ReflectionPad2d(3))
+        layers.append(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias))
+        layers.append(norm_layer(ngf))
+        layers.append(nn.ReLU(True))
+
+        n_downsampling = 2
+
+        # downsampling: each stride-2 conv doubles the channel count
+        for level in range(n_downsampling):
+            mult = 2 ** level
+            layers.append(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias))
+            layers.append(norm_layer(ngf * mult * 2))
+            layers.append(nn.ReLU(True))
+
+        # ResNet blocks at the lowest resolution
+        mult = 2 ** n_downsampling
+        for _ in range(n_blocks):
+            layers.append(ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
+                                      use_dropout=use_dropout, use_bias=use_bias))
+
+        # upsampling: each transposed conv halves the channel count
+        for level in range(n_downsampling):
+            mult = 2 ** (n_downsampling - level)
+            halved = int(ngf * mult / 2)
+            layers.append(nn.ConvTranspose2d(ngf * mult, halved,
+                                             kernel_size=3, stride=2,
+                                             padding=1, output_padding=1,
+                                             bias=use_bias))
+            layers.append(norm_layer(halved))
+            layers.append(nn.ReLU(True))
+
+        # 7x7 output head followed by tanh
+        layers.append(nn.ReflectionPad2d(3))
+        layers.append(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0))
+        layers.append(nn.Tanh())
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, input):
+        """Run the input through the sequential model."""
+        return self.model(input)
+
+
+class ResnetBlock(nn.Module):
+    """Residual block: output = input + conv_block(input)."""
+
+    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+        """Initialize the Resnet block
+
+        The conv block is assembled by build_conv_block; the skip connection
+        is applied in forward.
+        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+        """
+        super(ResnetBlock, self).__init__()
+        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
+
+    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+        """Construct a convolutional block.
+
+        Parameters:
+            dim (int)          -- the number of channels in the conv layer.
+            padding_type (str) -- the name of padding layer: reflect | replicate | zero
+            norm_layer         -- normalization layer
+            use_dropout (bool) -- if use dropout layers.
+            use_bias (bool)    -- if the conv layer uses bias or not
+
+        Returns nn.Sequential of: [pad,] conv, norm, ReLU, [dropout,] [pad,] conv, norm.
+        Raises NotImplementedError for an unknown padding_type.
+        """
+        def padding_for_conv():
+            # Returns (explicit padding layers, padding argument for the 3x3 conv).
+            if padding_type == 'reflect':
+                return [nn.ReflectionPad2d(1)], 0
+            if padding_type == 'replicate':
+                return [nn.ReplicationPad2d(1)], 0
+            if padding_type == 'zero':
+                # zero padding is done by the convolution itself
+                return [], 1
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+        layers = []
+
+        # first half: pad -> conv -> norm -> ReLU (-> dropout)
+        pad_layers, p = padding_for_conv()
+        layers.extend(pad_layers)
+        layers.extend([nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)])
+        if use_dropout:
+            layers.append(nn.Dropout(0.5))
+
+        # second half: pad -> conv -> norm (no activation)
+        pad_layers, p = padding_for_conv()
+        layers.extend(pad_layers)
+        layers.extend([nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)])
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Forward function (with skip connections)"""
+        residual = self.conv_block(x)
+        return x + residual
+
+
+class UnetGenerator(nn.Module):
+    """U-Net generator assembled from nested UnetSkipConnectionBlock modules."""
+
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        """Construct a Unet generator
+
+        Parameters:
+            input_nc (int)     -- the number of channels in input images
+            output_nc (int)    -- the number of channels in output images
+            num_downs (int)    -- the number of downsamplings in UNet. For example, if num_downs == 7,
+                                  an image of size 128x128 becomes 1x1 at the bottleneck
+            ngf (int)          -- the number of filters in the last conv layer
+            norm_layer         -- normalization layer
+            use_dropout (bool) -- if use dropout layers (only applied to the intermediate ngf * 8 blocks)
+
+        The network is built from the innermost block outwards, each new block
+        wrapping the previous one as its submodule.
+        """
+        super(UnetGenerator, self).__init__()
+
+        # innermost (bottleneck) block
+        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
+                                        norm_layer=norm_layer, innermost=True)
+
+        # num_downs - 5 intermediate blocks that keep ngf * 8 filters
+        for _ in range(num_downs - 5):
+            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
+                                            norm_layer=norm_layer, use_dropout=use_dropout)
+
+        # three blocks stepping the filter count down: (4, 8) -> (2, 4) -> (1, 2) times ngf
+        for factor in (4, 2, 1):
+            block = UnetSkipConnectionBlock(ngf * factor, ngf * factor * 2, input_nc=None,
+                                            submodule=block, norm_layer=norm_layer)
+
+        # outermost block maps input_nc -> output_nc
+        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
+                                             outermost=True, norm_layer=norm_layer)
+
+    def forward(self, input):
+        """Run the input through the nested blocks."""
+        return self.model(input)
+
+
+class UnetSkipConnectionBlock(nn.Module):
+    """One U-Net level with a skip connection.
+
+        X -------------------identity----------------------
+        |-- downsampling -- |submodule| -- upsampling --|
+    """
+
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        """Construct a Unet submodule with skip connections.
+
+        Parameters:
+            outer_nc (int) -- the number of filters in the outer conv layer
+            inner_nc (int) -- the number of filters in the inner conv layer
+            input_nc (int) -- the number of channels in input images/features (defaults to outer_nc)
+            submodule (UnetSkipConnectionBlock) -- previously defined submodules
+            outermost (bool)    -- if this module is the outermost module
+            innermost (bool)    -- if this module is the innermost module
+            norm_layer          -- normalization layer
+            use_dropout (bool)  -- if use dropout layers (only used for blocks that are neither outermost nor innermost)
+        """
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outermost = outermost
+
+        # Convolutions carry a bias only when the norm layer is InstanceNorm2d.
+        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
+        use_bias = norm_cls == nn.InstanceNorm2d
+
+        if input_nc is None:
+            input_nc = outer_nc
+
+        # Building blocks shared by the three variants below.
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = norm_layer(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = norm_layer(outer_nc)
+
+        if outermost:
+            # The submodule output has inner_nc * 2 channels because of its skip concatenation.
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1)
+            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
+        elif innermost:
+            # No submodule, hence no concatenation before the up-convolution.
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            layers = [downrelu, downconv, uprelu, upconv, upnorm]
+        else:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            layers = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
+            if use_dropout:
+                layers.append(nn.Dropout(0.5))
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Return the model output; non-outermost blocks concatenate the input along dim 1."""
+        if not self.outermost:
+            # add skip connections
+            return torch.cat([x, self.model(x)], 1)
+        return self.model(x)
+
+
+class NLayerDiscriminator(nn.Module):
+    """PatchGAN discriminator: a stack of strided 4x4 convolutions ending in a 1-channel map."""
+
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+        """Construct a PatchGAN discriminator
+
+        Parameters:
+            input_nc (int) -- the number of channels in input images
+            ndf (int)      -- the number of filters in the first conv layer
+            n_layers (int) -- the number of conv layers in the discriminator
+            norm_layer     -- normalization layer
+        """
+        super(NLayerDiscriminator, self).__init__()
+        # no need to use bias when the norm layer is not InstanceNorm2d (BatchNorm2d has affine parameters)
+        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
+        use_bias = norm_cls == nn.InstanceNorm2d
+
+        kw, padw = 4, 1
+
+        # first layer: conv + LeakyReLU, no normalization
+        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+
+        # (exponent, stride) per stage: n_layers - 1 stride-2 stages, then one stride-1 stage.
+        # The filter multiplier is 2 ** exponent, capped at 8.
+        stages = [(n, 2) for n in range(1, n_layers)] + [(n_layers, 1)]
+
+        mult_out = 1
+        for exponent, stride in stages:
+            mult_in, mult_out = mult_out, min(2 ** exponent, 8)
+            layers += [
+                nn.Conv2d(ndf * mult_in, ndf * mult_out, kernel_size=kw, stride=stride, padding=padw, bias=use_bias),
+                norm_layer(ndf * mult_out),
+                nn.LeakyReLU(0.2, True)
+            ]
+
+        # output 1 channel prediction map
+        layers.append(nn.Conv2d(ndf * mult_out, 1, kernel_size=kw, stride=1, padding=padw))
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, input):
+        """Run the input through the sequential model."""
+        return self.model(input)
+
+
+class PixelDiscriminator(nn.Module):
+    """1x1 PatchGAN discriminator (pixelGAN): every convolution has kernel size 1."""
+
+    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
+        """Construct a 1x1 PatchGAN discriminator
+
+        Parameters:
+            input_nc (int) -- the number of channels in input images
+            ndf (int)      -- the number of filters in the first conv layer
+            norm_layer     -- normalization layer
+        """
+        super(PixelDiscriminator, self).__init__()
+        # no need to use bias when the norm layer is not InstanceNorm2d (BatchNorm2d has affine parameters)
+        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
+        use_bias = norm_cls == nn.InstanceNorm2d
+
+        # input_nc -> ndf -> 2 * ndf -> 1 channel, spatial size unchanged
+        stack = []
+        stack.append(nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0))
+        stack.append(nn.LeakyReLU(0.2, True))
+        stack.append(nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias))
+        stack.append(norm_layer(ndf * 2))
+        stack.append(nn.LeakyReLU(0.2, True))
+        stack.append(nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias))
+
+        self.net = nn.Sequential(*stack)
+
+    def forward(self, input):
+        """Run the input through the sequential net."""
+        return self.net(input)
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/models/pix2pix4depth_model.py b/sd-webui-controlnet/annotator/leres/pix2pix/models/pix2pix4depth_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e89652feb96314973a050c5a2477b474630abb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/models/pix2pix4depth_model.py
@@ -0,0 +1,155 @@
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class Pix2Pix4DepthModel(BaseModel):
+    """pix2pix-style conditional GAN mapping a 2-channel input ('outer' and 'inner' maps) to a 1-channel output.
+
+    The generator is always a 'unet_1024' U-Net without normalization layers (see __init__).
+    When training, a discriminator chosen by the options ('--netD'), a GAN loss
+    ('--gan_mode') and an L1 loss weighted by '--lambda_L1' are used as well.
+
+    pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
+    """
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        """Add model-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase; training adds '--lambda_L1'
+                               and sets the defaults pool_size=0 and gan_mode='vanilla'.
+
+        Returns:
+            the modified parser.
+
+        The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
+        """
+        parser.set_defaults(input_nc=2,output_nc=1,norm='none', netG='unet_1024', dataset_mode='depthmerge')
+        if is_train:
+            parser.set_defaults(pool_size=0, gan_mode='vanilla',)
+            parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss')
+        return parser
+
+    def __init__(self, opt):
+        """Initialize the pix2pix class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags
+
+        NOTE(review): self.isTrain, self.gpu_ids, self.device and self.optimizers are
+        presumably set up by BaseModel.__init__, which is not visible here -- confirm.
+        """
+        BaseModel.__init__(self, opt)
+        # names of the training losses (each has a matching self.loss_<name> attribute)
+        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
+        # self.loss_names = ['G_L1']
+
+        # names of the images to save/display
+        if self.isTrain:
+            self.visual_names = ['outer','inner', 'fake_B', 'real_B']
+        else:
+            self.visual_names = ['fake_B']
+
+        # names of the networks to save to / load from disk
+        if self.isTrain:
+            self.model_names = ['G','D']
+        else:  # during test time, only load G
+            self.model_names = ['G']
+
+        # Generator: architecture, norm, dropout and init settings are fixed here
+        # rather than taken from opt (only the channel counts come from opt).
+        self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none',
+                                      False, 'normal', 0.02, self.gpu_ids)
+
+        if self.isTrain:
+            # Conditional GAN: the discriminator sees input and output together,
+            # so its channel count is input_nc + output_nc.
+            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+        if self.isTrain:
+            # define loss functions
+            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+            self.criterionL1 = torch.nn.L1Loss()
+            # optimizers (learning rates are fixed here; only beta1 comes from opt)
+            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999))
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999))
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+
+    def set_input_train(self, input):
+        """Unpack a training batch.
+
+        Parameters:
+            input (dict) -- must contain 'data_outer', 'data_inner' and 'image_path',
+                            plus 'data_gtfake' when self.isTrain is set.
+
+        All maps are moved to self.device and bilinearly resized to 1024x1024;
+        real_A is the channel-wise concatenation (outer, inner).
+        """
+        self.outer = input['data_outer'].to(self.device)
+        self.outer = torch.nn.functional.interpolate(self.outer,(1024,1024),mode='bilinear',align_corners=False)
+
+        self.inner = input['data_inner'].to(self.device)
+        self.inner = torch.nn.functional.interpolate(self.inner,(1024,1024),mode='bilinear',align_corners=False)
+
+        self.image_paths = input['image_path']
+
+        if self.isTrain:
+            self.gtfake = input['data_gtfake'].to(self.device)
+            self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False)
+            self.real_B = self.gtfake
+
+        self.real_A = torch.cat((self.outer, self.inner), 1)
+
+    @staticmethod
+    def _min_max_scale(tensor):
+        """Rescale a tensor so that its minimum maps to 0 and its maximum to 1.
+
+        A constant tensor (max == min) is mapped to all zeros instead of
+        the NaNs that a 0/0 division would produce.
+        """
+        shifted = tensor - torch.min(tensor)
+        span = torch.max(tensor) - torch.min(tensor)
+        if span == 0:
+            # multiplying by 0.0 yields zeros with the same dtype a division would give
+            return shifted * 0.0
+        return shifted / span
+
+    def set_input(self, outer, inner):
+        """Build real_A from two numpy arrays.
+
+        Parameters:
+            outer, inner (numpy arrays) -- each gets two leading singleton dimensions added,
+                                           is min-max scaled to [0, 1] and then mapped to [-1, 1].
+
+        The results are concatenated along dim 1 as (outer, inner) and moved to self.device.
+        A constant input array ends up as all -1.
+        """
+        inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0)
+        outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0)
+
+        inner = self._min_max_scale(inner)
+        outer = self._min_max_scale(outer)
+
+        inner = self.normalize(inner)
+        outer = self.normalize(outer)
+
+        self.real_A = torch.cat((outer, inner), 1).to(self.device)
+
+    def normalize(self, input):
+        """Map values from [0, 1] to [-1, 1] (computes input * 2 - 1)."""
+        input = input * 2
+        input = input - 1
+        return input
+
+    def forward(self):
+        """Run the generator on real_A and store the result in fake_B."""
+        self.fake_B = self.netG(self.real_A)  # G(A)
+
+    def backward_D(self):
+        """Calculate GAN loss for the discriminator"""
+        # Fake; stop backprop to the generator by detaching fake_B
+        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # conditional GAN: feed both input and output to the discriminator
+        pred_fake = self.netD(fake_AB.detach())
+        self.loss_D_fake = self.criterionGAN(pred_fake, False)
+        # Real
+        real_AB = torch.cat((self.real_A, self.real_B), 1)
+        pred_real = self.netD(real_AB)
+        self.loss_D_real = self.criterionGAN(pred_real, True)
+        # combine loss and calculate gradients
+        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+        self.loss_D.backward()
+
+    def backward_G(self):
+        """Calculate GAN and L1 loss for the generator"""
+        # First, G(A) should fake the discriminator
+        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
+        pred_fake = self.netD(fake_AB)
+        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+        # Second, G(A) = B
+        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
+        # combine loss and calculate gradients
+        self.loss_G = self.loss_G_L1 + self.loss_G_GAN
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        """Run one training step: forward pass, then update D, then update G."""
+        self.forward()                   # compute fake images: G(A)
+        # update D
+        self.set_requires_grad(self.netD, True)  # enable backprop for D
+        self.optimizer_D.zero_grad()     # set D's gradients to zero
+        self.backward_D()                # calculate gradients for D
+        self.optimizer_D.step()          # update D's weights
+        # update G
+        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
+        self.optimizer_G.zero_grad()        # set G's gradients to zero
+        self.backward_G()                   # calculate gradients for G
+        self.optimizer_G.step()             # update G's weights
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/options/__init__.py b/sd-webui-controlnet/annotator/leres/pix2pix/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/options/__init__.py
@@ -0,0 +1 @@
+"""This package includes the option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/options/base_options.py b/sd-webui-controlnet/annotator/leres/pix2pix/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..533a1e88a7e8494223f6994e6861c93667754f83
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/options/base_options.py
@@ -0,0 +1,156 @@
+import argparse
+import os
+from ...pix2pix.util import util
+# import torch
+from ...pix2pix import models
+# import pix2pix.data
+import numpy as np
+
+class BaseOptions():
+    """Options shared by training and test time.
+
+    Besides declaring the common command-line flags, the class offers helpers to
+    parse them, print/save them, and let the selected model class adjust the parser.
+    ``gather_options`` and ``parse`` read ``self.isTrain``, which this class never
+    sets; a subclass is expected to assign it (e.g. in its ``initialize``).
+    """
+
+    def __init__(self):
+        """Reset the class; indicates the class hasn't been initialized"""
+        self.initialized = False
+
+    def initialize(self, parser):
+        """Define the common options that are used in both training and test.
+
+        Parameters:
+            parser (argparse.ArgumentParser) -- parser to register the flags on
+
+        Returns the same parser, and marks the instance as initialized.
+        """
+        # basic parameters
+        parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+        parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet')
+        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+        parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here')
+        # model parameters
+        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
+        parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale')
+        parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
+        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
+        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
+        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
+        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
+        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
+        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
+        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
+        # dataset parameters
+        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
+        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
+        parser.add_argument('--load_size', type=int, default=672, help='scale images to this size')
+        parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size')
+        parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
+        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
+        # additional parameters
+        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        # The default is an int so parser.get_default('load_iter') matches the parsed value's type.
+        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
+        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+        parser.add_argument('--data_dir', type=str, required=False,
+                            help='input files directory images can be .png .jpg .tiff')
+        parser.add_argument('--output_dir', type=str, required=False,
+                            help='result dir. result depth will be png. vides are JMPG as avi')
+        parser.add_argument('--savecrops', type=int, required=False)
+        parser.add_argument('--savewholeest', type=int, required=False)
+        parser.add_argument('--output_resolution', type=int, required=False,
+                            help='0 for no restriction 1 for resize to input size')
+        parser.add_argument('--net_receptive_field_size', type=int, required=False)
+        parser.add_argument('--pix2pixsize', type=int, required=False)
+        parser.add_argument('--generatevideo', type=int, required=False)
+        parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1:strurturedRL')
+        parser.add_argument('--R0', action='store_true')
+        parser.add_argument('--R20', action='store_true')
+        parser.add_argument('--Final', action='store_true')
+        parser.add_argument('--colorize_results', action='store_true')
+        parser.add_argument('--max_res', type=float, default=np.inf)
+
+        self.initialized = True
+        return parser
+
+    def gather_options(self):
+        """Build the parser, let the selected model adjust it, and parse the command line.
+
+        Unknown command-line arguments are ignored (parse_known_args is used).
+        The final parser is cached on ``self.parser``; the parsed namespace is returned.
+        """
+        # The parser is built unconditionally. Guarding its creation with
+        # `self.initialized` leaves `parser` unbound on any later call, because
+        # initialize() flips that flag to True.
+        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+        parser = self.initialize(parser)
+
+        # get the basic options
+        opt, _ = parser.parse_known_args()
+
+        # modify model-related parser options
+        model_name = opt.model
+        model_option_setter = models.get_option_setter(model_name)
+        parser = model_option_setter(parser, self.isTrain)
+        opt, _ = parser.parse_known_args()  # parse again with new defaults
+
+        # modify dataset-related parser options
+        # dataset_name = opt.dataset_mode
+        # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name)
+        # parser = dataset_option_setter(parser, self.isTrain)
+
+        # save and return the parser
+        self.parser = parser
+        # parse_args() is deliberately not used here: it would reject the unknown
+        # arguments that parse_known_args() tolerates.
+        return opt
+
+    def print_options(self, opt):
+        """Print and save options
+
+        It will print both current options and default values(if different).
+        It will save options into a text file / [checkpoints_dir] / [name] / [phase]_opt.txt
+        """
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+
+        # save to the disk
+        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+        util.mkdirs(expr_dir)
+        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+        with open(file_name, 'wt') as opt_file:
+            opt_file.write(message)
+            opt_file.write('\n')
+
+    def parse(self):
+        """Parse our options, apply the name suffix, and turn gpu_ids into a list of ints.
+
+        Returns the options namespace, which is also stored on ``self.opt``.
+        """
+        opt = self.gather_options()
+        opt.isTrain = self.isTrain   # train or test
+
+        # process opt.suffix: the suffix is a format template filled from the options
+        if opt.suffix:
+            opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))
+
+        #self.print_options(opt)
+
+        # set gpu ids: keep the non-negative ids, skip empty tokens such as a trailing comma
+        gpu_ids = []
+        for token in opt.gpu_ids.split(','):
+            token = token.strip()
+            if not token:
+                continue
+            gpu_id = int(token)
+            if gpu_id >= 0:
+                gpu_ids.append(gpu_id)
+        opt.gpu_ids = gpu_ids
+        #if len(opt.gpu_ids) > 0:
+        #    torch.cuda.set_device(opt.gpu_ids[0])
+
+        self.opt = opt
+        return self.opt
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/options/test_options.py b/sd-webui-controlnet/annotator/leres/pix2pix/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3424b5e3b66d6813f74c8cecad691d7488d121c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/options/test_options.py
@@ -0,0 +1,22 @@
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+    """Test-time options, layered on top of the shared flags declared by BaseOptions."""
+
+    def initialize(self, parser):
+        """Register the test-only flags, override some shared defaults, and mark the run as non-training."""
+        parser = BaseOptions.initialize(self, parser)  # shared options first
+        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+        # Dropout and batchnorm behave differently during training and test.
+        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
+        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
+        # Override default values. load_size is set equal to crop_size to avoid cropping.
+        parser.set_defaults(
+            model='pix2pix4depth',
+            load_size=parser.get_default('crop_size'),
+        )
+        self.isTrain = False
+        return parser
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/__init__.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae36f63d8859ec0c60dcbfe67c4ac324e751ddf7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/__init__.py
@@ -0,0 +1 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/get_data.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/get_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..97edc3ce3c3ab6d6080dca34e73a5fb77bb715fb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/get_data.py
@@ -0,0 +1,110 @@
+from __future__ import print_function
+import os
+import tarfile
+import requests
+from warnings import warn
+from zipfile import ZipFile
+from bs4 import BeautifulSoup
+from os.path import abspath, isdir, join, basename
+
+
+class GetData(object):
+    """A Python script for downloading CycleGAN or pix2pix datasets.
+
+    Parameters:
+        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
+        verbose (bool)  -- If True, print additional information.
+
+    Examples:
+        >>> from util.get_data import GetData
+        >>> gd = GetData(technique='cyclegan')
+        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.
+
+    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
+    and 'scripts/download_cyclegan_model.sh'.
+    """
+
+    def __init__(self, technique='cyclegan', verbose=True):
+        url_dict = {
+            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
+            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
+        }
+        # An unknown technique leaves self.url as None.
+        self.url = url_dict.get(technique.lower())
+        self._verbose = verbose
+
+    def _print(self, text):
+        """Print ``text`` only when the instance was created with verbose=True."""
+        if self._verbose:
+            print(text)
+
+    @staticmethod
+    def _get_options(r):
+        """Return the link texts in response ``r`` that look like dataset archives."""
+        soup = BeautifulSoup(r.text, 'lxml')
+        # Same suffixes that _download_data knows how to unpack.
+        options = [h.text for h in soup.find_all('a', href=True)
+                   if h.text.endswith(('.zip', '.tar.gz'))]
+        return options
+
+    def _present_options(self):
+        """List the archives found at self.url and ask the user to pick one by index."""
+        r = requests.get(self.url)
+        r.raise_for_status()  # do not parse an HTTP error page as a listing
+        options = self._get_options(r)
+        print('Options:\n')
+        for i, o in enumerate(options):
+            print("{0}: {1}".format(i, o))
+        choice = input("\nPlease enter the number of the "
+                       "dataset above you wish to download:")
+        return options[int(choice)]
+
+    def _download_data(self, dataset_url, save_path):
+        """Download the archive at ``dataset_url`` into ``save_path`` and unpack it there.
+
+        Raises ValueError for an archive type other than .tar.gz / .zip. The check
+        runs before anything is downloaded or created on disk.
+        """
+        base = basename(dataset_url)
+        if base.endswith('.tar.gz'):
+            open_archive = tarfile.open
+        elif base.endswith('.zip'):
+            open_archive = ZipFile
+        else:
+            raise ValueError("Unknown File Type: {0}.".format(base))
+
+        if not isdir(save_path):
+            os.makedirs(save_path)
+        temp_save_path = join(save_path, base)
+
+        try:
+            # Stream to disk in chunks so the whole archive is not held in memory,
+            # and fail on HTTP errors instead of saving the error page.
+            r = requests.get(dataset_url, stream=True)
+            try:
+                r.raise_for_status()
+                with open(temp_save_path, "wb") as f:
+                    for chunk in r.iter_content(chunk_size=1 << 20):
+                        f.write(chunk)
+            finally:
+                r.close()
+
+            self._print("Unpacking Data...")
+            # NOTE(review): extractall() trusts the member paths stored in the
+            # downloaded archive; a malicious archive could write outside
+            # save_path. Only use this with trusted dataset hosts.
+            with open_archive(temp_save_path) as obj:
+                obj.extractall(save_path)
+        finally:
+            # The temporary archive is removed on failure as well as on success.
+            if os.path.exists(temp_save_path):
+                os.remove(temp_save_path)
+
+    def get(self, save_path, dataset=None):
+        """
+
+        Download a dataset.
+
+        Parameters:
+            save_path (str) -- A directory to save the data to.
+            dataset (str)   -- (optional). A specific dataset to download.
+                            Note: this must include the file extension.
+                            If None, options will be presented for you
+                            to choose from.
+
+        Returns:
+            save_path_full (str) -- the absolute path to the downloaded data.
+
+        """
+        if dataset is None:
+            selected_dataset = self._present_options()
+        else:
+            selected_dataset = dataset
+
+        # The target directory is the archive name up to its first dot.
+        save_path_full = join(save_path, selected_dataset.split('.')[0])
+
+        if isdir(save_path_full):
+            warn("\n'{0}' already exists. Voiding Download.".format(
+                save_path_full))
+        else:
+            self._print('Downloading Data...')
+            # One base URL ends with '/', the other does not; normalise before joining.
+            url = "{0}/{1}".format(self.url.rstrip('/'), selected_dataset)
+            self._download_data(url, save_path=save_path)
+
+        return abspath(save_path_full)
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/guidedfilter.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/guidedfilter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d377ff12e078a5f156e9246b63573dae71825fad
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/guidedfilter.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+class GuidedFilter():
+    """Guided filter whose result is computed at construction time and stored in ``self.smooth``.
+
+    NOTE(review): the slicing in boxfilter presumably expects 2-D arrays with more
+    than 2*r+1 rows and columns -- confirm against callers.
+    """
+
+    def __init__(self, source, reference, r=64, eps= 0.05**2):
+        self.source, self.reference = source, reference
+        self.r, self.eps = r, eps
+
+        self.smooth = self.guidedfilter(self.source, self.reference, self.r, self.eps)
+
+    def boxfilter(self, img, r):
+        """Sum ``img`` over a window of radius ``r`` around every pixel, using cumulative sums.
+
+        The window is truncated where it would extend past the array border.
+        """
+        (n_rows, n_cols) = img.shape
+        out = np.zeros_like(img)
+
+        # Pass 1: window sums along the row axis (top border, interior, bottom border).
+        running = np.cumsum(img, 0)
+        out[0:r + 1, :] = running[r:2 * r + 1, :]
+        out[r + 1:n_rows - r, :] = running[2 * r + 1:n_rows, :] - running[0:n_rows - 2 * r - 1, :]
+        out[n_rows - r:n_rows, :] = np.tile(running[n_rows - 1, :], [r, 1]) - running[n_rows - 2 * r - 1:n_rows - r - 1, :]
+
+        # Pass 2: the same along the column axis, applied to the result of pass 1.
+        running = np.cumsum(out, 1)
+        out[:, 0:r + 1] = running[:, r:2 * r + 1]
+        out[:, r + 1:n_cols - r] = running[:, 2 * r + 1:n_cols] - running[:, 0:n_cols - 2 * r - 1]
+        out[:, n_cols - r:n_cols] = np.tile(running[:, n_cols - 1], [r, 1]).T - running[:, n_cols - 2 * r - 1:n_cols - r - 1]
+
+        return out
+
+    def guidedfilter(self, I, p, r, eps):
+        """Filter ``p`` with ``I`` as the guide and return ``mean(a) * I + mean(b)``.
+
+        ``a = cov(I, p) / (var(I) + eps)`` and ``b = mean(p) - a * mean(I)``, with
+        all means taken over the box window of radius ``r``.
+        """
+        (n_rows, n_cols) = I.shape
+        # Number of pixels actually covered by the window at each position.
+        counts = self.boxfilter(np.ones([n_rows, n_cols]), r)
+
+        def local_mean(values):
+            return self.boxfilter(values, r) / counts
+
+        mean_guide = local_mean(I)
+        mean_input = local_mean(p)
+        covariance = local_mean(I * p) - mean_guide * mean_input
+        variance = local_mean(I * I) - mean_guide * mean_guide
+
+        # Per-window linear model: p ~ a * I + b
+        a = covariance / (variance + eps)
+        b = mean_input - a * mean_guide
+
+        return local_mean(a) * I + local_mean(b)
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/html.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+    """This HTML class allows us to save images and write texts into a single HTML file.
+
+    It consists of functions such as add_header (add a text header to the HTML file),
+    add_images (add a row of images to the HTML file), and save (save the HTML to the disk).
+    It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+    """
+
+    def __init__(self, web_dir, title, refresh=0):
+        """Initialize the HTML classes
+
+        Parameters:
+            web_dir (str) -- a directory that stores the webpage. HTML file will be created at [web_dir]/index.html; images will be saved at [web_dir]/images/
+            title (str)   -- the webpage name
+            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+        """
+        # NOTE(review): the statements between the docstring and the refresh check
+        # were lost in this copy of the file. They are reconstructed from how the
+        # other methods use self.web_dir / self.img_dir / self.doc -- confirm
+        # against the upstream pix2pix util/html.py.
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.web_dir):
+            os.makedirs(self.web_dir)
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            # Ask the browser to reload the page every `refresh` seconds.
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+    def get_image_dir(self):
+        """Return the directory that stores images"""
+        return self.img_dir
+
+    def add_header(self, text):
+        """Insert a header to the HTML file
+
+        Parameters:
+            text (str) -- the header text
+        """
+        with self.doc:
+            h3(text)
+
+    def add_images(self, ims, txts, links, width=400):
+        """add images to the HTML file
+
+        Parameters:
+            ims (str list)   -- a list of image paths
+            txts (str list)  -- a list of image names shown on the website
+            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+            width (int)      -- display width of each image, in pixels
+        """
+        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
+        self.doc.add(self.t)
+        with self.t:
+            with tr():
+                # One table cell per image: the linked image with its caption underneath.
+                for im, txt, link in zip(ims, txts, links):
+                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                        with p():
+                            with a(href=os.path.join('images', link)):
+                                img(style="width:%dpx" % width, src=os.path.join('images', im))
+                            br()
+                            p(txt)
+
+    def save(self):
+        """save the current content to the HTML file"""
+        html_file = '%s/index.html' % self.web_dir
+        # The context manager closes the file even if rendering or writing fails.
+        with open(html_file, 'wt') as f:
+            f.write(self.doc.render())
+
+
+if __name__ == '__main__':  # example usage: one header and a row of four images
+    html = HTML('web/', 'test_html')
+    html.add_header('hello world')
+
+    ims = ['image_%d.png' % n for n in range(4)]
+    txts = ['text_%d' % n for n in range(4)]
+    links = list(ims)  # each image links to itself
+    html.add_images(ims, txts, links)
+    html.save()
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/image_pool.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/image_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d086f882bc3d1b90c529fce6cddaaa75f2005d7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/image_pool.py
@@ -0,0 +1,54 @@
+import random
+import torch
+
+
+class ImagePool():
+    """Buffer of previously generated images.
+
+    The buffer lets discriminators be updated with a history of generated images
+    instead of only the ones produced by the latest generators.
+    """
+
+    def __init__(self, pool_size):
+        """Create the pool.
+
+        Parameters:
+            pool_size (int) -- capacity of the buffer; with pool_size=0 no buffer is created
+        """
+        self.pool_size = pool_size
+        if self.pool_size > 0:
+            # start with an empty pool
+            self.num_imgs = 0
+            self.images = []
+
+    def query(self, images):
+        """Return a batch assembled from the inputs and the pool.
+
+        Parameters:
+            images: the latest generated images from the generator
+
+        With pool_size == 0 the input is returned unchanged. While the pool is not
+        full, every input image is stored and returned as-is. Once it is full, each
+        input image is, with probability 0.5, swapped with a randomly chosen stored
+        image (the stored one is returned); otherwise the input image is returned
+        and the pool is left untouched.
+        """
+        if self.pool_size == 0:
+            return images
+
+        selected = []
+        for image in images:
+            candidate = torch.unsqueeze(image.data, 0)
+            if self.num_imgs < self.pool_size:
+                # Pool still filling up: store the image and pass it through.
+                self.num_imgs += 1
+                self.images.append(candidate)
+                selected.append(candidate)
+                continue
+            if random.uniform(0, 1) > 0.5:
+                # Hand out an older image and keep the new one in its slot.
+                slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
+                older = self.images[slot].clone()
+                self.images[slot] = candidate
+                selected.append(older)
+            else:
+                selected.append(candidate)
+        return torch.cat(selected, 0)
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/util.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a7aceaa00681cb76675df7866bf8db58c8d2caf
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/util.py
@@ -0,0 +1,105 @@
+"""This module contains simple helper functions """
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+
+
+def tensor2im(input_image, imtype=np.uint16):
+    """Convert a tensor into a numpy image array.
+
+    Parameters:
+        input_image (tensor) -- the input image tensor array
+        imtype (type)        -- the desired type of the converted numpy array
+
+    A numpy array is only cast to ``imtype``. A torch tensor is squeezed, moved to
+    the CPU and rescaled with (x + 1) / 2 * 65535 before the cast. Any other
+    object is returned unchanged.
+    """
+    if isinstance(input_image, np.ndarray):
+        return input_image.astype(imtype)
+    if not isinstance(input_image, torch.Tensor):
+        return input_image
+    scaled = torch.squeeze(input_image.data).cpu().numpy()
+    scaled = (scaled + 1) / 2.0 * (2**16-1)
+    return scaled.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+    """Print ``name`` followed by the average, over parameters, of mean(abs(gradient)).
+
+    Parameters:
+        net (torch network) -- Torch network
+        name (str)          -- the name of the network
+
+    Parameters without a gradient are skipped; if none has one, 0.0 is printed.
+    """
+    gradients = [param.grad for param in net.parameters() if param.grad is not None]
+    total = 0.0
+    for grad in gradients:
+        total = total + torch.mean(torch.abs(grad.data))
+    average = total / len(gradients) if gradients else total
+    print(name)
+    print(average)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+    """Save a numpy image to the disk in PIL mode 'I;16'.
+
+    Parameters:
+        image_numpy (numpy array) -- input numpy array
+        image_path (str)          -- the path of the image
+        aspect_ratio (float)      -- accepted for interface compatibility; not used here
+    """
+    # The aspect-ratio resizing (BICUBIC resize of the h x w x c image) is
+    # intentionally disabled; the array is written at its own size.
+    converted = Image.fromarray(image_numpy).convert('I;16')
+    converted.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+    """Print summary statistics (mean, min, max, median, std) and/or the shape of a numpy array.
+
+    Parameters:
+        x (numpy array) -- the array to describe; it is converted to float64 first
+        val (bool)      -- if True, print the statistics
+        shp (bool)      -- if True, print the shape
+    """
+    values = x.astype(np.float64)
+    if shp:
+        print('shape,', values.shape)
+    if not val:
+        return
+    flat = values.flatten()
+    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
+    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
+
+
+def mkdirs(paths):
+    """Create one directory, or every directory in a list, if they don't exist yet.
+
+    Parameters:
+        paths (str or str list) -- a single directory path or a list of directory paths
+    """
+    # Only a real list is iterated; anything else (e.g. a str) is one path.
+    targets = paths if isinstance(paths, list) else [paths]
+    for target in targets:
+        if not os.path.exists(target):
+            os.makedirs(target)
+
+
+def mkdir(path):
+    """Create a directory (including missing parents) unless the path already exists.
+
+    Parameters:
+        path (str) -- a single directory path
+    """
+    if os.path.exists(path):
+        return
+    os.makedirs(path)
diff --git a/sd-webui-controlnet/annotator/leres/pix2pix/util/visualizer.py b/sd-webui-controlnet/annotator/leres/pix2pix/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..810a0513ab997103ace77b665c9a17f223b173c9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/leres/pix2pix/util/visualizer.py
@@ -0,0 +1,166 @@
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+import torch
+
+
+# Base class for the exception caught around visdom calls: Python 2 has no
+# ConnectionError, so plain Exception is used there.
+VisdomExceptionBase = Exception if sys.version_info[0] == 2 else ConnectionError
+
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+    """Write every image in ``visuals`` to the webpage's image directory and add them to the page.
+
+    Parameters:
+        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+        image_path (str)         -- sequence of paths; the base name of its first element (extension dropped) prefixes the saved files
+        aspect_ratio (float)     -- forwarded to util.save_image
+        width (int)              -- forwarded to webpage.add_images as the display width
+
+    Each image is saved as '[name]_[label].png', and one row with all of them is
+    appended to the page under a header showing the name.
+    """
+    image_dir = webpage.get_image_dir()
+    # ntpath.basename is used to take the final path component.
+    name = os.path.splitext(ntpath.basename(image_path[0]))[0]
+    webpage.add_header(name)
+
+    file_names, labels = [], []
+    for label, im_data in visuals.items():
+        file_name = '%s_%s.png' % (name, label)
+        util.save_image(util.tensor2im(im_data), os.path.join(image_dir, file_name), aspect_ratio=aspect_ratio)
+        file_names.append(file_name)
+        labels.append(label)
+    # Every image links to itself.
+    webpage.add_images(file_names, labels, list(file_names), width=width)
+
+
+class Visualizer():
+    """This class includes several functions that can display/save images and print/save logging information.
+
+    It uses a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+    A visdom server can be launched through create_visdom_connections; the visdom
+    plotting method itself is commented out in this copy.
+    """
+
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: create the web/image directories when HTML output is enabled
+        Step 3: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the option
+        self.display_id = opt.display_id
+        self.use_html = opt.isTrain and not opt.no_html
+        self.win_size = opt.display_winsize
+        self.name = opt.name
+        self.port = opt.display_port
+        self.saved = False
+
+        if self.use_html:  # create an HTML object at [checkpoints_dir]/[name]/web/; images will be saved under [checkpoints_dir]/[name]/web/images/
+            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+            self.img_dir = os.path.join(self.web_dir, 'images')
+            print('create web directory %s...' % self.web_dir)
+            util.mkdirs([self.web_dir, self.img_dir])
+        # create a logging file to store training losses
+        log_dir = os.path.join(opt.checkpoints_dir, opt.name)
+        # The experiment directory is only created above when HTML output is on;
+        # make sure it exists so that opening the log cannot fail.
+        util.mkdirs(log_dir)
+        self.log_name = os.path.join(log_dir, 'loss_log.txt')
+        with open(self.log_name, "a") as log_file:
+            now = time.strftime("%c")
+            log_file.write('================ Training Loss (%s) ================\n' % now)
+
+    def reset(self):
+        """Reset the self.saved status"""
+        self.saved = False
+
+    def create_visdom_connections(self):
+        """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
+        # NOTE(review): the command relies on shell syntax ('&>/dev/null &') and is
+        # run with shell=True; the pipes opened here are never read.
+        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+        print('Command: %s' % cmd)
+        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+    def display_current_results(self, visuals, epoch, save_result):
+        """Save current results to image files and rebuild the HTML page that lists them.
+
+        Parameters:
+            visuals (OrderedDict) - - dictionary of images to display or save
+            epoch (int) - - the current epoch
+            save_result (bool) - - if save the current results to an HTML file
+
+        Nothing happens unless HTML output is enabled and either save_result is
+        set or nothing has been saved since the last reset().
+        """
+        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
+            self.saved = True
+            # save images to the disk
+            for label, image in visuals.items():
+                image_numpy = util.tensor2im(image)
+                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                util.save_image(image_numpy, img_path)
+
+            # update website: one section per epoch, newest first
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
+            for n in range(epoch, 0, -1):
+                webpage.add_header('epoch [%d]' % n)
+                ims, txts, links = [], [], []
+
+                # Only the labels are needed here; the file names follow the
+                # pattern used when the images were saved.
+                for label, image_numpy in visuals.items():
+                    # image_numpy = util.tensor2im(image)
+                    img_path = 'epoch%.3d_%s.png' % (n, label)
+                    ims.append(img_path)
+                    txts.append(label)
+                    links.append(img_path)
+                webpage.add_images(ims, txts, links, width=self.win_size)
+            webpage.save()
+
+    # def plot_current_losses(self, epoch, counter_ratio, losses):
+    #     """display the current losses on visdom display: dictionary of error labels and values
+    #
+    #     Parameters:
+    #         epoch (int)           -- current epoch
+    #         counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
+    #         losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
+    #     """
+    #     if not hasattr(self, 'plot_data'):
+    #         self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+    #     self.plot_data['X'].append(epoch + counter_ratio)
+    #     self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+    #     try:
+    #         self.vis.line(
+    #             X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+    #             Y=np.array(self.plot_data['Y']),
+    #             opts={
+    #                 'title': self.name + ' loss over time',
+    #                 'legend': self.plot_data['legend'],
+    #                 'xlabel': 'epoch',
+    #                 'ylabel': 'loss'},
+    #             win=self.display_id)
+    #     except VisdomExceptionBase:
+    #         self.create_visdom_connections()
+
+    # losses: same format as |losses| of plot_current_losses
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+        """print current losses on console; also save the losses to the disk
+
+        Parameters:
+            epoch (int) -- current epoch
+            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+            t_comp (float) -- computational time per data point (normalized by batch_size)
+            t_data (float) -- data loading time per data point (normalized by batch_size)
+        """
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.3f ' % (k, v)
+
+        print(message)  # print the message
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)  # save the message
diff --git a/sd-webui-controlnet/annotator/lineart/LICENSE b/sd-webui-controlnet/annotator/lineart/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lineart/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/lineart/__init__.py b/sd-webui-controlnet/annotator/lineart/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e43c501f64aeb170cf933d06a63bd9dfd4f4e7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lineart/__init__.py
@@ -0,0 +1,133 @@
+import os
+import cv2
+import torch
+import numpy as np
+
+import torch.nn as nn
+from einops import rearrange
+from modules import devices
+from annotator.annotator_path import models_path
+
+
+# Normalisation layer constructor shared by ResidualBlock and Generator below.
+norm_layer = nn.InstanceNorm2d
+
+
+class ResidualBlock(nn.Module):
+    """Residual unit: two reflection-padded 3x3 convolutions whose output is added to the input."""
+
+    def __init__(self, in_features):
+        """Build the convolutional branch.
+
+        Parameters:
+            in_features (int) -- channel count used as both input and output of each convolution
+        """
+        super(ResidualBlock, self).__init__()
+
+        # Each 3x3 convolution is preceded by ReflectionPad2d(1); only the first is followed by ReLU.
+        conv_block = [ nn.ReflectionPad2d(1),
+                       nn.Conv2d(in_features, in_features, 3),
+                       norm_layer(in_features),
+                       nn.ReLU(inplace=True),
+                       nn.ReflectionPad2d(1),
+                       nn.Conv2d(in_features, in_features, 3),
+                       norm_layer(in_features)
+                       ]
+
+        self.conv_block = nn.Sequential(*conv_block)
+
+    def forward(self, x):
+        """Return ``x`` plus the output of the convolutional branch applied to ``x``."""
+        return x + self.conv_block(x)
+
+
+class Generator(nn.Module):
+    """Convolutional generator assembled from five sequential stages (model0 .. model4).
+
+    Stages, in order: a 7x7 stem, two stride-2 convolutions, a stack of ResidualBlock
+    modules, two stride-2 transposed convolutions, and a 7x7 output convolution.
+    """
+
+    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
+        """Construct the generator.
+
+        Parameters:
+            input_nc (int)          -- number of channels of the input tensor
+            output_nc (int)         -- number of channels of the output tensor
+            n_residual_blocks (int) -- number of ResidualBlock modules in model2
+            sigmoid (bool)          -- if True, append nn.Sigmoid() after the output convolution
+        """
+        super(Generator, self).__init__()
+
+        # Initial convolution block
+        model0 = [ nn.ReflectionPad2d(3),
+                   nn.Conv2d(input_nc, 64, 7),
+                   norm_layer(64),
+                   nn.ReLU(inplace=True) ]
+        self.model0 = nn.Sequential(*model0)
+
+        # Downsampling: the channel count doubles at each of the two stride-2 steps (64 -> 128 -> 256).
+        model1 = []
+        in_features = 64
+        out_features = in_features*2
+        for _ in range(2):
+            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                        norm_layer(out_features),
+                        nn.ReLU(inplace=True) ]
+            in_features = out_features
+            out_features = in_features*2
+        self.model1 = nn.Sequential(*model1)
+
+        model2 = []
+        # Residual blocks (operate on the channel count left by the downsampling stage)
+        for _ in range(n_residual_blocks):
+            model2 += [ResidualBlock(in_features)]
+        self.model2 = nn.Sequential(*model2)
+
+        # Upsampling: the channel count halves at each of the two transposed-convolution steps.
+        model3 = []
+        out_features = in_features//2
+        for _ in range(2):
+            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                        norm_layer(out_features),
+                        nn.ReLU(inplace=True) ]
+            in_features = out_features
+            out_features = in_features//2
+        self.model3 = nn.Sequential(*model3)
+
+        # Output layer
+        model4 = [ nn.ReflectionPad2d(3),
+                   nn.Conv2d(64, output_nc, 7)]
+        if sigmoid:
+            model4 += [nn.Sigmoid()]
+
+        self.model4 = nn.Sequential(*model4)
+
+    def forward(self, x, cond=None):
+        """Run ``x`` through model0 .. model4 in order. ``cond`` is accepted but not used."""
+        out = self.model0(x)
+        out = self.model1(out)
+        out = self.model2(out)
+        out = self.model3(out)
+        out = self.model4(out)
+
+        return out
+
+
+class LineartDetector:
+    """Line-art extractor that wraps Generator and loads its weights on first use."""
+
+    # Directory where checkpoints are looked up and downloaded to.
+    model_dir = os.path.join(models_path, "lineart")
+    # Checkpoint file names callers can pass as ``model_name``.
+    model_default = 'sk_model.pth'
+    model_coarse = 'sk_model2.pth'
+
+    def __init__(self, model_name):
+        """Remember which checkpoint to use; the network itself is created lazily in __call__."""
+        self.model = None
+        self.model_name = model_name
+        self.device = devices.get_device_for("controlnet")
+
+    def load_model(self, name):
+        """Download checkpoint ``name`` if it is not in ``model_dir``, then build and load the network."""
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
+        model_path = os.path.join(self.model_dir, name)
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=self.model_dir)
+        # 3 input channels, 1 output channel, 3 residual blocks.
+        model = Generator(3, 1, 3)
+        # Weights are deserialized on the CPU and moved to the target device afterwards.
+        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        model.eval()
+        self.model = model.to(self.device)
+
+    def unload_model(self):
+        """Move the network to the CPU, if one is loaded; the instance keeps its reference to it."""
+        if self.model is not None:
+            self.model.cpu()
+
+    def __call__(self, input_image):
+        """Run the network on ``input_image`` and return a uint8 array.
+
+        ``input_image`` must be a 3-dimensional array that is rearranged as 'h w c'.
+        Values are divided by 255 before inference; the result is channel 0 of the
+        first batch item, multiplied by 255 and clipped to [0, 255].
+        """
+        if self.model is None:
+            self.load_model(self.model_name)
+        # Bring the network back to the target device in case unload_model() moved it to the CPU.
+        self.model.to(self.device)
+
+        assert input_image.ndim == 3
+        image = input_image
+        with torch.no_grad():
+            image = torch.from_numpy(image).float().to(self.device)
+            image = image / 255.0
+            image = rearrange(image, 'h w c -> 1 c h w')
+            line = self.model(image)[0][0]
+
+            line = line.cpu().numpy()
+            line = (line * 255.0).clip(0, 255).astype(np.uint8)
+
+            return line
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/lineart_anime/LICENSE b/sd-webui-controlnet/annotator/lineart_anime/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lineart_anime/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/lineart_anime/__init__.py b/sd-webui-controlnet/annotator/lineart_anime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dddfa97930f27e7ffc7604a6da1f1a08d117ea3b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/lineart_anime/__init__.py
@@ -0,0 +1,161 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import functools
+
+import os
+import cv2
+from einops import rearrange
+from modules import devices
+from annotator.annotator_path import models_path
+
+
+class UnetGenerator(nn.Module):
+    """Create a Unet-based generator"""
+
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        """Construct a Unet generator
+        Parameters:
+            input_nc (int)     -- the number of channels in input images
+            output_nc (int)    -- the number of channels in output images
+            num_downs (int)    -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
+                                  an image of size 128x128 will become of size 1x1 at the bottleneck
+            ngf (int)          -- the number of filters in the last conv layer
+            norm_layer         -- normalization layer
+            use_dropout (bool) -- passed only to the intermediate ngf * 8 blocks
+
+        We construct the U-Net from the innermost layer to the outermost layer.
+        It is a recursive process.
+        """
+        super(UnetGenerator, self).__init__()
+        # construct unet structure
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
+        for _ in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        # gradually reduce the number of filters from ngf * 8 to ngf
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer
+
+    def forward(self, input):
+        """Standard forward: apply the outermost block (which contains all nested blocks) to ``input``."""
+        return self.model(input)
+
+
+class UnetSkipConnectionBlock(nn.Module):
+    """Defines the Unet submodule with skip connection.
+        X -------------------identity----------------------
+        |-- downsampling -- |submodule| -- upsampling --|
+    """
+
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        """Construct a Unet submodule with skip connections.
+        Parameters:
+            outer_nc (int) -- the number of filters in the outer conv layer
+            inner_nc (int) -- the number of filters in the inner conv layer
+            input_nc (int) -- the number of channels in input images/features;
+                              defaults to outer_nc when None
+            submodule (UnetSkipConnectionBlock) -- previously defined submodules
+            outermost (bool)    -- if this module is the outermost module
+            innermost (bool)    -- if this module is the innermost module
+            norm_layer          -- normalization layer
+            use_dropout (bool)  -- if use dropout layers (only honoured by intermediate blocks).
+        """
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outermost = outermost
+        # Convolutions get a bias term only when the normalization layer is InstanceNorm2d
+        # (given either directly or wrapped in functools.partial).
+        if type(norm_layer) == functools.partial:
+            use_bias = norm_layer.func == nn.InstanceNorm2d
+        else:
+            use_bias = norm_layer == nn.InstanceNorm2d
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = norm_layer(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = norm_layer(outer_nc)
+
+        if outermost:
+            # inner_nc * 2 input channels: the submodule returns its input concatenated with its output.
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1)
+            down = [downconv]
+            up = [uprelu, upconv, nn.Tanh()]
+            model = down + [submodule] + up
+        elif innermost:
+            # No submodule here, so the up-convolution takes inner_nc channels.
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            down = [downrelu, downconv]
+            up = [uprelu, upconv, upnorm]
+            model = down + up
+        else:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, bias=use_bias)
+            down = [downrelu, downconv, downnorm]
+            up = [uprelu, upconv, upnorm]
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        """Return the block output; every block except the outermost concatenates ``x`` along dim 1."""
+        if self.outermost:
+            return self.model(x)
+        else:   # add skip connections
+            return torch.cat([x, self.model(x)], 1)
+
+
+class LineartAnimeDetector:
+ model_dir = os.path.join(models_path, "lineart_anime")
+
+ def __init__(self):
+ self.model = None
+ self.device = devices.get_device_for("controlnet")
+
+ def load_model(self):
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth"
+ modelpath = os.path.join(self.model_dir, "netG.pth")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+ net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
+ ckpt = torch.load(modelpath)
+ for key in list(ckpt.keys()):
+ if 'module.' in key:
+ ckpt[key.replace('module.', '')] = ckpt[key]
+ del ckpt[key]
+ net.load_state_dict(ckpt)
+ net.eval()
+ self.model = net.to(self.device)
+
+ def unload_model(self):
+ if self.model is not None:
+ self.model.cpu()
+
+ def __call__(self, input_image):
+ if self.model is None:
+ self.load_model()
+ self.model.to(self.device)
+
+ H, W, C = input_image.shape
+ Hn = 256 * int(np.ceil(float(H) / 256.0))
+ Wn = 256 * int(np.ceil(float(W) / 256.0))
+ img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)
+ with torch.no_grad():
+ image_feed = torch.from_numpy(img).float().to(self.device)
+ image_feed = image_feed / 127.5 - 1.0
+ image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
+
+ line = self.model(image_feed)[0, 0] * 127.5 + 127.5
+ line = line.cpu().numpy()
+
+ line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC)
+ line = line.clip(0, 255).astype(np.uint8)
+ return line
+
diff --git a/sd-webui-controlnet/annotator/manga_line/LICENSE b/sd-webui-controlnet/annotator/manga_line/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..9bad05450ca061904f97acebe04ff7183cfbdc1a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/manga_line/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Miaomiao Li
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/sd-webui-controlnet/annotator/manga_line/__init__.py b/sd-webui-controlnet/annotator/manga_line/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c797fd22c563a775162c8fd646dfa50077e4ac16
--- /dev/null
+++ b/sd-webui-controlnet/annotator/manga_line/__init__.py
@@ -0,0 +1,248 @@
+import os
+import torch
+import torch.nn as nn
+from PIL import Image
+import fnmatch
+import cv2
+
+import sys
+
+import numpy as np
+from einops import rearrange
+from modules import devices
+from annotator.annotator_path import models_path
+
+
+class _bn_relu_conv(nn.Module):
+ def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+ super(_bn_relu_conv, self).__init__()
+ self.model = nn.Sequential(
+ nn.BatchNorm2d(in_filters, eps=1e-3),
+ nn.LeakyReLU(0.2),
+ nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+ # the following are for debugs
+ print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
+ for i,layer in enumerate(self.model):
+ if i != 2:
+ x = layer(x)
+ else:
+ x = layer(x)
+ #x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0)
+ print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape)
+ print(x[0])
+ return x
+
+class _u_bn_relu_conv(nn.Module):
+    """BatchNorm2d -> LeakyReLU(0.2) -> Conv2d -> nearest-neighbour 2x upsampling."""
+
+    def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
+        """Same arguments as _bn_relu_conv: channel counts, kernel size (fw, fh) and conv stride."""
+        super(_u_bn_relu_conv, self).__init__()
+        self.model = nn.Sequential(
+            nn.BatchNorm2d(in_filters, eps=1e-3),
+            nn.LeakyReLU(0.2),
+            nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
+            nn.Upsample(scale_factor=2, mode='nearest')
+        )
+
+    def forward(self, x):
+        """Apply the sequence to ``x``."""
+        return self.model(x)
+
+
+
+class _shortcut(nn.Module):
+    """Adds two tensors, projecting the first with a 1x1 convolution when required."""
+
+    def __init__(self, in_filters, nb_filters, subsample=1):
+        """Create the 1x1 projection only if the channel counts differ or ``subsample`` is not 1."""
+        super(_shortcut, self).__init__()
+        self.process = False
+        self.model = None
+        if in_filters != nb_filters or subsample != 1:
+            self.process = True
+            self.model = nn.Sequential(
+                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
+            )
+
+    def forward(self, x, y):
+        """Return ``x + y``, with ``x`` passed through the projection first when one exists."""
+        #print(x.size(), y.size(), self.process)
+        if self.process:
+            y0 = self.model(x)
+            #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
+            return y0 + y
+        else:
+            #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
+            return x + y
+
+class _u_shortcut(nn.Module):
+    """Shortcut for upsampling blocks: 1x1 convolution plus 2x nearest upsampling, then addition."""
+
+    def __init__(self, in_filters, nb_filters, subsample):
+        """Create the projection only if the channel counts differ (``subsample`` alone does not trigger it)."""
+        super(_u_shortcut, self).__init__()
+        self.process = False
+        self.model = None
+        if in_filters != nb_filters:
+            self.process = True
+            self.model = nn.Sequential(
+                nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
+                nn.Upsample(scale_factor=2, mode='nearest')
+            )
+
+    def forward(self, x, y):
+        """Return ``x + y``, with ``x`` projected and upsampled first when a projection exists."""
+        if self.process:
+            return self.model(x) + y
+        else:
+            return x + y
+
+
+class basic_block(nn.Module):
+    """Residual unit: two 3x3 _bn_relu_conv layers merged with the input through _shortcut."""
+
+    def __init__(self, in_filters, nb_filters, init_subsample=1):
+        """``init_subsample`` is the stride of the first convolution and of the shortcut projection."""
+        super(basic_block, self).__init__()
+        self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+        self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+    def forward(self, x):
+        """Return shortcut(x, residual(conv1(x)))."""
+        x1 = self.conv1(x)
+        x2 = self.residual(x1)
+        return self.shortcut(x, x2)
+
+class _u_basic_block(nn.Module):
+    """Upsampling residual unit: _u_bn_relu_conv, then _bn_relu_conv, merged through _u_shortcut."""
+
+    def __init__(self, in_filters, nb_filters, init_subsample=1):
+        """``init_subsample`` is the stride of the first convolution and of the shortcut projection."""
+        super(_u_basic_block, self).__init__()
+        self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
+        self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
+        self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
+
+    def forward(self, x):
+        """Return shortcut(x, residual(conv1(x)))."""
+        y = self.residual(self.conv1(x))
+        return self.shortcut(x, y)
+
+
+class _residual_block(nn.Module):
+    """A sequence of ``repetitions`` basic_block modules."""
+
+    def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
+        """Build the sequence.
+
+        The first block maps in_filters -> nb_filters; the rest keep nb_filters.
+        The last block uses stride 2 unless ``is_first_layer`` is True.
+        """
+        super(_residual_block, self).__init__()
+        layers = []
+        for i in range(repetitions):
+            init_subsample = 1
+            if i == repetitions - 1 and not is_first_layer:
+                init_subsample = 2
+            if i == 0:
+                l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+            else:
+                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
+            layers.append(l)
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Apply the blocks to ``x`` in order."""
+        return self.model(x)
+
+
+class _upsampling_residual_block(nn.Module):
+    """One _u_basic_block followed by ``repetitions - 1`` basic_block modules."""
+
+    def __init__(self, in_filters, nb_filters, repetitions):
+        """The first block maps in_filters -> nb_filters; the remaining blocks keep nb_filters."""
+        super(_upsampling_residual_block, self).__init__()
+        layers = []
+        for i in range(repetitions):
+            l = None
+            if i == 0:
+                l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input)
+            else:
+                l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input)
+            layers.append(l)
+
+        self.model = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Apply the blocks to ``x`` in order."""
+        return self.model(x)
+
+
+class res_skip(nn.Module):
+    """Encoder-decoder of residual blocks with additive skip connections.
+
+    block0..block4 form the encoder, block5..block8 the decoder; res1..res4 add each
+    decoder output to the encoder output with the same channel count. block9 and
+    conv15 reduce the result to a single channel.
+    """
+
+    def __init__(self):
+        super(res_skip, self).__init__()
+        # Trailing comments name the tensor each module is applied to in forward().
+        self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input)
+        self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0)
+        self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1)
+        self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2)
+        self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3)
+
+        self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4)
+        self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1))
+
+        self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1)
+        self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1))
+
+        self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2)
+        self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1))
+
+        self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3)
+        self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1))
+
+        self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4)
+        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block9)
+
+    def forward(self, x):
+        """Run the encoder, then the decoder with skip additions, and return the 1-channel output."""
+        x0 = self.block0(x)
+        x1 = self.block1(x0)
+        x2 = self.block2(x1)
+        x3 = self.block3(x2)
+        x4 = self.block4(x3)
+
+        x5 = self.block5(x4)
+        res1 = self.res1(x3, x5)
+
+        x6 = self.block6(res1)
+        res2 = self.res2(x2, x6)
+
+        x7 = self.block7(res2)
+        res3 = self.res3(x1, x7)
+
+        x8 = self.block8(res3)
+        res4 = self.res4(x0, x8)
+
+        x9 = self.block9(res4)
+        y = self.conv15(x9)
+
+        return y
+
+
+class MangaLineExtration:
+ model_dir = os.path.join(models_path, "manga_line")
+
+ def __init__(self):
+ self.model = None
+ self.device = devices.get_device_for("controlnet")
+
+ def load_model(self):
+ remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/erika.pth"
+ modelpath = os.path.join(self.model_dir, "erika.pth")
+ if not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=self.model_dir)
+ #norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+ net = res_skip()
+ ckpt = torch.load(modelpath)
+ for key in list(ckpt.keys()):
+ if 'module.' in key:
+ ckpt[key.replace('module.', '')] = ckpt[key]
+ del ckpt[key]
+ net.load_state_dict(ckpt)
+ net.eval()
+ self.model = net.to(self.device)
+
+ def unload_model(self):
+ if self.model is not None:
+ self.model.cpu()
+
+ def __call__(self, input_image):
+ if self.model is None:
+ self.load_model()
+ self.model.to(self.device)
+ img = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
+ img = np.ascontiguousarray(img.copy()).copy()
+ with torch.no_grad():
+ image_feed = torch.from_numpy(img).float().to(self.device)
+ image_feed = rearrange(image_feed, 'h w -> 1 1 h w')
+ line = self.model(image_feed)
+ line = 255 - line.cpu().numpy()[0, 0]
+ return line.clip(0, 255).astype(np.uint8)
+
+
diff --git a/sd-webui-controlnet/annotator/mediapipe_face/__init__.py b/sd-webui-controlnet/annotator/mediapipe_face/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74edfb187e4e39583ed92bfe69ea29c42a34ddc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mediapipe_face/__init__.py
@@ -0,0 +1,5 @@
+from .mediapipe_face_common import generate_annotation
+
+
+def apply_mediapipe_face(image, max_faces: int = 1, min_confidence: float = 0.5):
+    """Thin wrapper that forwards its arguments unchanged to generate_annotation()."""
+    return generate_annotation(image, max_faces, min_confidence)
diff --git a/sd-webui-controlnet/annotator/mediapipe_face/mediapipe_face_common.py b/sd-webui-controlnet/annotator/mediapipe_face/mediapipe_face_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f7d3701dc40eee88977f17a877fa800d0ae328d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mediapipe_face/mediapipe_face_common.py
@@ -0,0 +1,155 @@
+from typing import Mapping
+
+import mediapipe as mp
+import numpy
+
+
+# Short aliases for the mediapipe solution modules and connection sets used in this file.
+mp_drawing = mp.solutions.drawing_utils
+mp_drawing_styles = mp.solutions.drawing_styles
+mp_face_detection = mp.solutions.face_detection  # Only for counting faces.
+mp_face_mesh = mp.solutions.face_mesh
+mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION
+mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS
+mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS
+
+DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+PoseLandmark = mp.solutions.drawing_styles.PoseLandmark
+
+# Faces whose smaller bounding-box side is below this many pixels are dropped by generate_annotation().
+min_face_size_pixels: int = 64
+# Line thickness and circle radius shared by every DrawingSpec below.
+f_thick = 2
+f_rad = 1
+# One colour per facial feature.
+right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad)
+head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+
+# mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+# Map every mesh edge that should be drawn to the DrawingSpec of the feature it belongs to.
+face_connection_spec = {}
+for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+    face_connection_spec[edge] = head_draw
+for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+    face_connection_spec[edge] = left_eye_draw
+for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+    face_connection_spec[edge] = left_eyebrow_draw
+# for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+#    face_connection_spec[edge] = left_iris_draw
+for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+    face_connection_spec[edge] = right_eye_draw
+for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+    face_connection_spec[edge] = right_eyebrow_draw
+# for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+#    face_connection_spec[edge] = right_iris_draw
+for edge in mp_face_mesh.FACEMESH_LIPS:
+    face_connection_spec[edge] = mouth_draw
+# Landmark index -> DrawingSpec for the two landmarks that draw_pupils() marks.
+iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+
+
+def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2):
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+ if len(image.shape) != 3:
+ raise ValueError("Input image must be H,W,C.")
+ image_rows, image_cols, image_channels = image.shape
+ if image_channels != 3: # BGR channels
+ raise ValueError('Input image must contain three channel bgr data.')
+ for idx, landmark in enumerate(landmark_list.landmark):
+ if (
+ (landmark.HasField('visibility') and landmark.visibility < 0.9) or
+ (landmark.HasField('presence') and landmark.presence < 0.5)
+ ):
+ continue
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+ continue
+ image_x = int(image_cols*landmark.x)
+ image_y = int(image_rows*landmark.y)
+ draw_color = None
+ if isinstance(drawing_spec, Mapping):
+ if drawing_spec.get(idx) is None:
+ continue
+ else:
+ draw_color = drawing_spec[idx].color
+ elif isinstance(drawing_spec, DrawingSpec):
+ draw_color = drawing_spec.color
+ image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color
+
+
+def reverse_channels(image):
+    """Given a numpy array in RGB form, convert to BGR.  Will also convert from BGR to RGB."""
+    # im[:,:,::-1] converts BGR to RGB (and back) by reversing the order of the third axis.
+    # im[:,:,[2,1,0]] would also work but makes a copy of the data.
+    return image[:, :, ::-1]
+
+
+def generate_annotation(
+        img_rgb,
+        max_faces: int,
+        min_confidence: float
+):
+    """
+    Find up to 'max_faces' inside the provided input image and draw their face-mesh contours.
+
+    If the module-level min_face_size_pixels is greater than zero, faces whose bounding box
+    (computed from the landmarks) has a smaller side below that many pixels are dropped.
+
+    Parameters:
+        img_rgb        -- H x W x 3 image array; it is passed to FaceMesh.process() as-is
+        max_faces      -- forwarded to FaceMesh as max_num_faces
+        min_confidence -- forwarded to FaceMesh as min_detection_confidence
+
+    Returns an array shaped like img_rgb: all zeros when no face is detected, otherwise the
+    drawn annotation with its channel order reversed after drawing.
+    """
+    with mp_face_mesh.FaceMesh(
+            static_image_mode=True,
+            max_num_faces=max_faces,
+            refine_landmarks=True,
+            min_detection_confidence=min_confidence,
+    ) as facemesh:
+        img_height, img_width, img_channels = img_rgb.shape
+        assert(img_channels == 3)
+
+        results = facemesh.process(img_rgb).multi_face_landmarks
+
+        if results is None:
+            print("No faces detected in controlnet image for Mediapipe face annotator.")
+            return numpy.zeros_like(img_rgb)
+
+        # Filter faces that are too small
+        filtered_landmarks = []
+        for lm in results:
+            landmarks = lm.landmark
+            # Bounding box in normalized coordinates, seeded with the first landmark.
+            face_rect = [
+                landmarks[0].x,
+                landmarks[0].y,
+                landmarks[0].x,
+                landmarks[0].y,
+            ]  # Left, up, right, down.
+            for i in range(len(landmarks)):
+                face_rect[0] = min(face_rect[0], landmarks[i].x)
+                face_rect[1] = min(face_rect[1], landmarks[i].y)
+                face_rect[2] = max(face_rect[2], landmarks[i].x)
+                face_rect[3] = max(face_rect[3], landmarks[i].y)
+            if min_face_size_pixels > 0:
+                # Convert the normalized box to pixels and compare its smaller side to the threshold.
+                face_width = abs(face_rect[2] - face_rect[0])
+                face_height = abs(face_rect[3] - face_rect[1])
+                face_width_pixels = face_width * img_width
+                face_height_pixels = face_height * img_height
+                face_size = min(face_width_pixels, face_height_pixels)
+                if face_size >= min_face_size_pixels:
+                    filtered_landmarks.append(lm)
+            else:
+                filtered_landmarks.append(lm)
+
+        # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start.
+        empty = numpy.zeros_like(img_rgb)
+
+        # Draw detected faces:
+        for face_landmarks in filtered_landmarks:
+            mp_drawing.draw_landmarks(
+                empty,
+                face_landmarks,
+                connections=face_connection_spec.keys(),
+                landmark_drawing_spec=None,
+                connection_drawing_spec=face_connection_spec
+            )
+            draw_pupils(empty, face_landmarks, iris_landmark_spec, 2)
+
+        # Flip BGR back to RGB.
+        empty = reverse_channels(empty).copy()
+
+        return empty
diff --git a/sd-webui-controlnet/annotator/midas/LICENSE b/sd-webui-controlnet/annotator/midas/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/sd-webui-controlnet/annotator/midas/__init__.py b/sd-webui-controlnet/annotator/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc247615fbdaeba9105512184ce39a5baab57b2b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/__init__.py
@@ -0,0 +1,49 @@
+import cv2
+import numpy as np
+import torch
+
+from einops import rearrange
+from .api import MiDaSInference
+from modules import devices
+
+# Module-level MiDaSInference instance; created on the first apply_midas() call.
+model = None
+
+def unload_midas_model():
+    """Move the module-level model to the CPU if it has been created; the reference is kept."""
+    global model
+    if model is not None:
+        model = model.cpu()
+
+def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
+    """Run MiDaS ("dpt_hybrid") on ``input_image`` and return ``(depth_image, normal_image)``.
+
+    Parameters:
+        input_image -- 3-dimensional array that is rearranged as 'h w c'
+        a           -- constant used as the z component of every normal before normalization
+        bg_th       -- pixels whose min-max-normalized depth is below this get zero x/y gradients
+
+    Both results are uint8 arrays. ``depth_image`` is the min-max-normalized depth scaled
+    to 0..255. ``normal_image`` is built from Sobel gradients of the raw depth and has its
+    channel order reversed.
+    """
+    global model
+    if model is None:
+        model = MiDaSInference(model_type="dpt_hybrid")
+    # On an 'mps' device neither the model nor the input is moved by this function.
+    if devices.get_device_for("controlnet").type != 'mps':
+        model = model.to(devices.get_device_for("controlnet"))
+
+    assert input_image.ndim == 3
+    image_depth = input_image
+    with torch.no_grad():
+        image_depth = torch.from_numpy(image_depth).float()
+        if devices.get_device_for("controlnet").type != 'mps':
+            image_depth = image_depth.to(devices.get_device_for("controlnet"))
+        image_depth = image_depth / 127.5 - 1.0
+        image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+        depth = model(image_depth)[0]
+
+        # Min-max normalize a copy of the depth map.
+        # NOTE(review): a constant depth map makes the divisor 0 -- confirm this cannot occur.
+        depth_pt = depth.clone()
+        depth_pt -= torch.min(depth_pt)
+        depth_pt /= torch.max(depth_pt)
+        depth_pt = depth_pt.cpu().numpy()
+        depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+        # Normals: x/y from Sobel gradients of the raw depth, z fixed to ``a``, then unit length.
+        depth_np = depth.cpu().numpy()
+        x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+        y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+        z = np.ones_like(x) * a
+        x[depth_pt < bg_th] = 0
+        y[depth_pt < bg_th] = 0
+        normal = np.stack([x, y, z], axis=2)
+        normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+        # Map [-1, 1] to [0, 255] and reverse the channel order.
+        normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1]
+
+        return depth_image, normal_image
diff --git a/sd-webui-controlnet/annotator/midas/api.py b/sd-webui-controlnet/annotator/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..72870381f336427b886dedef1c208c5f66c6f4cc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/api.py
@@ -0,0 +1,181 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+import os
+from annotator.annotator_path import models_path
+
+from torchvision.transforms import Compose
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+base_model_path = os.path.join(models_path, "midas")  # checkpoints are looked up here and downloaded into it
+old_modeldir = os.path.dirname(os.path.realpath(__file__))  # directory containing this file
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"  # download URL used by load_model() for "dpt_hybrid"
+# Checkpoint path per model type, under base_model_path ("" where no file is configured).
+ISL_PATHS = {
+ "dpt_large": os.path.join(base_model_path, "dpt_large-midas-2f21e586.pt"),
+ "dpt_hybrid": os.path.join(base_model_path, "dpt_hybrid-midas-501f0c75.pt"),
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+# Same layout rooted at old_modeldir; load_model() prefers it for "dpt_hybrid" when the file exists.
+OLD_ISL_PATHS = {
+ "dpt_large": os.path.join(old_modeldir, "dpt_large-midas-2f21e586.pt"),
+ "dpt_hybrid": os.path.join(old_modeldir, "dpt_hybrid-midas-501f0c75.pt"),
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+    """No-op replacement for ``model.train``: assign it over that attribute so the
+    train/eval mode can no longer be changed. ``mode`` is ignored."""
+    return self
+
+
+def load_midas_transform(model_type):
+    """Build only the input preprocessing pipeline for ``model_type`` (no network).
+
+    Adapted from https://github.com/isl-org/MiDaS/blob/master/run.py
+
+    Args:
+        model_type: "dpt_large", "dpt_hybrid", "midas_v21" or "midas_v21_small".
+
+    Returns:
+        ``Compose([Resize(...), NormalizeImage(...), PrepareForNet()])``.
+
+    Raises:
+        AssertionError: if ``model_type`` is none of the names above.
+    """
+    # Per-model settings: square network input side, resize method, normalisation stats.
+    if model_type in ("dpt_large", "dpt_hybrid"):  # DPT-Large / DPT-Hybrid
+        side = 384
+        resize_mode = "minimal"
+        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+
+    elif model_type in ("midas_v21", "midas_v21_small"):
+        side = 384 if model_type == "midas_v21" else 256
+        resize_mode = "upper_bound"
+        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    normalization = NormalizeImage(mean=mean, std=std)
+    # Aspect ratio is kept; output sides are constrained to multiples of 32.
+    resize = Resize(
+        side,
+        side,
+        resize_target=None,
+        keep_aspect_ratio=True,
+        ensure_multiple_of=32,
+        resize_method=resize_mode,
+        image_interpolation_method=cv2.INTER_CUBIC,
+    )
+
+    transform = Compose([resize, normalization, PrepareForNet()])
+
+    return transform
+
+
+def load_model(model_type):
+    """Instantiate the MiDaS network for ``model_type`` together with its input transform.
+
+    Adapted from https://github.com/isl-org/MiDaS/blob/master/run.py
+
+    For "dpt_hybrid" a checkpoint located next to this file (``OLD_ISL_PATHS``) is
+    preferred; otherwise the ``ISL_PATHS`` location is used, and the file is first
+    fetched from ``remote_model_path`` into ``base_model_path`` if it does not exist.
+    The other model types are built directly from their ``ISL_PATHS`` entry.
+
+    Args:
+        model_type: one of the keys of ``ISL_PATHS``.
+
+    Returns:
+        tuple: ``(model.eval(), transform)`` where ``transform`` is the result of
+        ``load_midas_transform(model_type)``.
+
+    Raises:
+        KeyError: if ``model_type`` is not a key of ``ISL_PATHS``.
+    """
+
+    model_path = ISL_PATHS[model_type]
+    old_model_path = OLD_ISL_PATHS[model_type]
+
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        if os.path.exists(old_model_path):
+            model_path = old_model_path
+        elif not os.path.exists(model_path):
+            # Imported lazily: only needed when the checkpoint has to be downloaded.
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=base_model_path)
+
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(
+            model_path,
+            features=64,
+            backbone="efficientnet_lite3",
+            exportable=True,
+            non_negative=True,
+            blocks={'expand': True},
+        )
+
+    else:
+        # Reached only if ISL_PATHS gains a key that has no branch above.
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        # Explicit raise instead of ``assert False`` so it also fires under ``python -O``.
+        raise AssertionError(f"model_type '{model_type}' not implemented")
+
+    # The preprocessing is identical to load_midas_transform(), so it is built there
+    # instead of keeping a second copy of the Resize/NormalizeImage setup here.
+    transform = load_midas_transform(model_type)
+
+    return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+    """Inference-only ``nn.Module`` wrapper around a network built by ``load_model``.
+
+    ``MODEL_TYPES_ISL`` lists the ``model_type`` values accepted by ``__init__``.
+    """
+
+    # Not referenced by this class itself.
+    MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"]
+    MODEL_TYPES_ISL = ["dpt_large", "dpt_hybrid", "midas_v21", "midas_v21_small"]
+
+    def __init__(self, model_type):
+        """Load the network for ``model_type``, which must be in ``MODEL_TYPES_ISL``."""
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        # load_model() also returns a transform; it is not used by this wrapper.
+        self.model, _ = load_model(model_type)
+        # Shadow ``train`` on the instance so that ``self.model.train(...)`` resolves to
+        # ``disabled_train`` from now on.
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        """Return ``self.model(x)`` computed under ``torch.no_grad()``."""
+        with torch.no_grad():
+            return self.model(x)
+
diff --git a/sd-webui-controlnet/annotator/midas/midas/__init__.py b/sd-webui-controlnet/annotator/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sd-webui-controlnet/annotator/midas/midas/base_model.py b/sd-webui-controlnet/annotator/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    """``torch.nn.Module`` with a helper for restoring weights from a checkpoint file."""
+
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
+        checkpoint = torch.load(path, map_location=torch.device('cpu'))
+        state = checkpoint["model"] if "optimizer" in checkpoint else checkpoint
+        self.load_state_dict(state)
diff --git a/sd-webui-controlnet/annotator/midas/midas/blocks.py b/sd-webui-controlnet/annotator/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+    """Create the ``backbone`` network and the matching ``scratch`` projection convs.
+
+    Returns:
+        tuple: ``(pretrained, scratch)``, where ``scratch`` is built by ``_make_scratch``
+        from the per-stage channel counts of the selected backbone.
+
+    Raises:
+        AssertionError: if ``backbone`` is not one of the names handled below.
+    """
+
+    if backbone == "vitl16_384":
+        pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout)
+        stage_channels = [256, 512, 1024, 1024]  # ViT-L/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb_rn50_384":
+        pretrained = _make_pretrained_vitb_rn50_384(
+            use_pretrained, hooks=hooks, use_vit_only=use_vit_only, use_readout=use_readout
+        )
+        stage_channels = [256, 512, 768, 768]  # ViT-H/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb16_384":
+        pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout)
+        stage_channels = [96, 192, 384, 768]  # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        stage_channels = [256, 512, 1024, 2048]  # resnext101_wsl
+    elif backbone == "efficientnet_lite3":
+        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+        stage_channels = [32, 48, 136, 384]  # efficientnet_lite3
+    else:
+        print(f"Backbone '{backbone}' not implemented")
+        assert False
+
+    # One projection conv per backbone stage, created after the backbone itself.
+    scratch = _make_scratch(stage_channels, features, groups=groups, expand=expand)
+
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    """Build the ``scratch`` module: one 3x3, bias-free projection conv per encoder stage.
+
+    Args:
+        in_shape: input channel counts of the four stages (indexed 0..3).
+        out_shape (int): base number of output channels.
+        groups (int): ``groups`` argument of every ``nn.Conv2d``.
+        expand: if ``== True`` the stages get ``out_shape`` x1, x2, x4 and x8 output
+            channels; otherwise all four get ``out_shape``.
+
+    Returns:
+        nn.Module: container with the attributes ``layer1_rn`` ... ``layer4_rn``.
+    """
+
+    scratch = nn.Module()
+
+    multipliers = (1, 2, 4, 8) if expand == True else (1, 1, 1, 1)
+
+    for stage, multiplier in enumerate(multipliers):
+        conv = nn.Conv2d(
+            in_shape[stage], out_shape * multiplier,
+            kernel_size=3, stride=1, padding=1, bias=False, groups=groups,
+        )
+        # Registered as layer1_rn, layer2_rn, layer3_rn, layer4_rn (in that order).
+        setattr(scratch, f"layer{stage + 1}_rn", conv)
+
+    return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+    """Load ``tf_efficientnet_lite3`` through ``torch.hub`` and regroup it into stages.
+
+    ``use_pretrained`` and ``exportable`` are forwarded to the hub entry point as
+    ``pretrained`` and ``exportable``.
+    """
+    hub_model = torch.hub.load("rwightman/gen-efficientnet-pytorch", "tf_efficientnet_lite3", pretrained=use_pretrained, exportable=exportable)
+    return _make_efficientnet_backbone(hub_model)
+
+
+def _make_efficientnet_backbone(effnet):
+    """Regroup an EfficientNet's stem and ``blocks`` into ``layer1`` ... ``layer4``."""
+    pretrained = nn.Module()
+
+    stem = [effnet.conv_stem, effnet.bn1, effnet.act1]
+    pretrained.layer1 = nn.Sequential(*stem, *effnet.blocks[0:2])
+    # Remaining stages: half-open index ranges into effnet.blocks.
+    for stage, (first, last) in (("layer2", (2, 3)), ("layer3", (3, 5)), ("layer4", (5, 9))):
+        setattr(pretrained, stage, nn.Sequential(*effnet.blocks[first:last]))
+
+    return pretrained
+
+
+def _make_resnet_backbone(resnet):
+    """Expose a ResNet-style network as the four stages ``layer1`` ... ``layer4``.
+
+    ``layer1`` additionally starts with the stem: conv1, bn1, relu, maxpool.
+    """
+    pretrained = nn.Module()
+    stem = (resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
+    pretrained.layer1 = nn.Sequential(*stem, resnet.layer1)
+    for stage in ("layer2", "layer3", "layer4"):
+        setattr(pretrained, stage, getattr(resnet, stage))
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    # ``use_pretrained`` is accepted but not forwarded to the hub call below.
+    return _make_resnet_backbone(torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl"))
+
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+
+    Applies ``nn.functional.interpolate`` with the settings given at construction.
+    """
+
+    def __init__(self, scale_factor, mode, align_corners=False):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+            align_corners (bool, optional): passed through to ``interpolate``.
+                Defaults to False.
+        """
+        super().__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor, self.mode = scale_factor, mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+
+        settings = dict(scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
+        return self.interp(x, **settings)
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+
+    Two ``ReLU -> 3x3 conv`` stages whose result is added to the (rectified) input.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        # Both convs keep the channel count and the spatial size.
+        same_size_conv = dict(kernel_size=3, stride=1, padding=1, bias=True)
+        self.conv1 = nn.Conv2d(features, features, **same_size_conv)
+        self.conv2 = nn.Conv2d(features, features, **same_size_conv)
+
+        # In-place ReLU: its first use in forward() also overwrites the input tensor.
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input; modified in place by the first ReLU
+
+        Returns:
+            tensor: output
+        """
+
+        branch = self.conv1(self.relu(x))
+        branch = self.conv2(self.relu(branch))
+
+        # ``x`` already holds relu(x) at this point (see the in-place ReLU above).
+        return branch + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+
+    Optionally adds a refined skip input to the main input, refines the sum and
+    upsamples it bilinearly by a factor of 2.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Args:
+            *xs (tensor): main input, optionally followed by one skip input
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+        if len(xs) == 2:
+            # Out-of-place add: ``+=`` would write the sum into the caller's xs[0].
+            output = output + self.resConfUnit1(xs[1])
+        output = self.resConfUnit2(output)
+        return nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True)
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+    """Residual convolution module.
+
+    Two ``activation -> 3x3 conv [-> batch norm]`` stages plus a skip
+    connection that is added through a ``FloatFunctional``.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+            activation (callable): applied before each convolution
+            bn: if ``== True``, a ``BatchNorm2d`` follows each convolution
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups = 1
+
+        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
+        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)
+
+        if self.bn == True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn == True:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn == True:
+            out = self.bn2(out)
+
+        # There is deliberately no ``conv_merge`` step for ``groups > 1``: the
+        # convs are always built with groups=1 and this class never creates such
+        # a layer, so calling it could only fail with AttributeError.
+
+        return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock_custom(nn.Module):
+    """Feature fusion block.
+
+    Optionally adds a refined skip input, refines the result, upsamples it by a
+    factor of 2 and applies a 1x1 output convolution.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+        """Init.
+
+        Args:
+            features (int): number of features
+            activation (callable): activation handed to both residual units
+            deconv (bool): stored as ``self.deconv``; not read elsewhere in this class
+            bn (bool): batch-norm switch handed to both residual units
+            expand: if ``== True`` the output convolution halves the channel count
+            align_corners (bool): ``align_corners`` of the bilinear upsampling
+        """
+        super().__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+        self.groups = 1
+        self.expand = expand
+
+        out_features = features // 2 if self.expand == True else features
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Args:
+            *xs (tensor): main input, optionally followed by one skip input
+
+        Returns:
+            tensor: output
+        """
+        fused = xs[0]
+
+        if len(xs) == 2:
+            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))
+
+        fused = self.resConfUnit2(fused)
+        fused = nn.functional.interpolate(fused, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
+
+        return self.out_conv(fused)
+
diff --git a/sd-webui-controlnet/annotator/midas/midas/dpt_depth.py b/sd-webui-controlnet/annotator/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+    """Create a ``FeatureFusionBlock_custom`` with the fixed settings used by ``DPT``.
+
+    Fixed: non in-place ReLU activation, ``deconv=False``, ``expand=False``,
+    ``align_corners=True``. ``use_bn`` is passed on as ``bn``.
+    """
+    fixed = dict(deconv=False, expand=False, align_corners=True)
+
+    return FeatureFusionBlock_custom(features, nn.ReLU(False), bn=use_bn, **fixed)
+
+
+class DPT(BaseModel):
+    """Dense prediction network: ViT encoder, four fusion blocks and an output ``head``."""
+
+    def __init__(self, head, features=256, backbone="vitb_rn50_384", readout="project", channels_last=False, use_bn=False):
+        """Build the encoder and the refinement blocks, then attach ``head``.
+
+        Args:
+            head (nn.Module): stored as ``scratch.output_conv`` and applied last.
+            features (int): channel count of the fusion blocks.
+            backbone (str): "vitb_rn50_384", "vitb16_384" or "vitl16_384".
+            readout (str): passed to ``_make_encoder`` as ``use_readout``.
+            channels_last (bool): convert the input to channels-last memory format.
+            use_bn (bool): enable batch norm inside the fusion blocks.
+        """
+        super().__init__()
+
+        self.channels_last = channels_last
+
+        # Values handed to the encoder factory as ``hooks``, per backbone.
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False,  # Set to true if you want to train from scratch, uses ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+
+        for index in (1, 2, 3, 4):
+            setattr(self.scratch, f"refinenet{index}", _make_fusion_block(features, use_bn))
+
+        self.scratch.output_conv = head
+
+    def forward(self, x):
+        """Return ``head`` applied to the fused encoder features of ``x``."""
+        if self.channels_last == True:
+            # ``contiguous`` returns the converted tensor, so the result has to be kept.
+            x = x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        return self.scratch.output_conv(path_1)
+
+
+class DPTDepthModel(DPT):
+    """``DPT`` with a single-channel depth head; can load weights on construction."""
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        """``path`` (if not None) is loaded after construction; ``non_negative`` ends the head
+        with a ReLU; ``kwargs`` go to ``DPT.__init__`` (``features``, default 256, sizes the head)."""
+        features = kwargs.get("features", 256)
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        super().__init__(head, **kwargs)
+        if path is not None:
+            self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)  # drop the single-channel dim
+
diff --git a/sd-webui-controlnet/annotator/midas/midas/midas_net.py b/sd-webui-controlnet/annotator/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): End the output head with a ReLU. Defaults to True.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        # Pretrained backbone weights are only useful when no checkpoint is loaded on
+        # top of them, so request them only if ``path`` is empty.
+        use_pretrained = not path
+
+        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+        return torch.squeeze(out, dim=1)
diff --git a/sd-webui-controlnet/annotator/midas/midas/midas_net_custom.py b/sd-webui-controlnet/annotator/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
+            non_negative (bool, optional): End the output head with a ReLU. Defaults to True.
+            exportable (bool, optional): Passed to ``_make_encoder``. Defaults to True.
+            channels_last (bool, optional): Convert inputs to channels-last memory format.
+            align_corners (bool, optional): Used by the fusion blocks' upsampling.
+            blocks (dict, optional): ``{'expand': True}`` widens the fusion stages to
+                x1/x2/x4/x8 ``features``. The default dict is shared and only read.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_small, self).__init__()
+
+        use_pretrained = False if path else True
+
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+
+        self.groups = 1
+
+        self.expand = "expand" in self.blocks and self.blocks['expand'] == True
+        if self.expand:
+            features1, features2, features3, features4 = features, features * 2, features * 4, features * 8
+        else:
+            features1 = features2 = features3 = features4 = features
+
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+        self.scratch.activation = nn.ReLU(False)
+
+        # refinenet1 is the only block created without ``expand`` (its default is False).
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            # ``contiguous`` returns the converted tensor; without the assignment
+            # the channels_last option has no effect.
+            x = x.contiguous(memory_format=torch.channels_last)
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):  # Walk m.named_modules() and fuse Conv2d -> BatchNorm2d (-> ReLU) runs of m in place.
+ prev_previous_type = nn.Identity()  # placeholder instance; never equal to the classes compared below
+ prev_previous_name = ''
+ previous_type = nn.Identity()  # placeholder instance, see above
+ previous_name = ''
+ for name, module in m.named_modules():  # NOTE(review): m is modified (inplace=True) while being iterated - confirm this is safe
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:  # Conv2d, BatchNorm2d, ReLU seen in a row
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:  # Conv2d, BatchNorm2d followed by something other than ReLU
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+ # Slide the two-module history (type and name) forward by one module.
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/midas/midas/transforms.py b/sd-webui-controlnet/annotator/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Resize the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample; its "image", "disparity" and "mask" entries are resized
+        size (tuple): minimum size, compared with the first two dims of "disparity"
+        image_interpolation_method: cv2 interpolation flag used for "image"
+
+    Returns:
+        tuple: new size. NOTE(review): when the sample is already large enough the
+        (unchanged) ``sample`` dict is returned instead - confirm callers expect that.
+    """
+    shape = list(sample["disparity"].shape)
+
+    # Nothing to do if both leading dims already reach the requested size.
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    # Use the larger of the two ratios so that both dims end up big enough.
+    scale = max(size[0] / shape[0], size[1] / shape[1])
+
+    shape[0], shape[1] = math.ceil(scale * shape[0]), math.ceil(scale * shape[1])
+
+    # The target size is passed to cv2.resize in reversed order relative to ``shape``.
+    dsize = tuple(shape[::-1])
+
+    sample["image"] = cv2.resize(
+        sample["image"], dsize, interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], dsize, interpolation=cv2.INTER_NEAREST
+    )
+
+    mask = cv2.resize(
+        sample["mask"].astype(np.float32), dsize, interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = mask.astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+            image_interpolation_method (int, optional):
+                cv2 interpolation flag used for the image. Defaults to cv2.INTER_AREA.
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        """Round ``x`` to a multiple of ``ensure_multiple_of`` while honouring the bounds.
+
+        The nearest multiple is used unless it exceeds ``max_val`` (then ``x`` is
+        rounded down) or the result is below ``min_val`` (then ``x`` is rounded up).
+        """
+        step = self.__multiple_of
+        y = (np.round(x / step) * step).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / step) * step).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / step) * step).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        """Compute the output size for an input of ``width`` x ``height``.
+
+        Returns:
+            tuple: ``(new_width, new_height)``
+
+        Raises:
+            ValueError: if ``resize_method`` is not one of the three supported names.
+        """
+        method = self.__resize_method
+
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            # Decide which of the two factors is applied to both axes.
+            if method == "lower_bound":
+                # scale such that output size is lower bound
+                fit_width = scale_width > scale_height
+            elif method == "upper_bound":
+                # scale such that output size is upper bound
+                fit_width = scale_width < scale_height
+            elif method == "minimal":
+                # scale as little as possible
+                fit_width = abs(1 - scale_width) < abs(1 - scale_height)
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+            if fit_width:
+                scale_height = scale_width
+            else:
+                scale_width = scale_height
+
+        # Bounds handed to constrain_to_multiple_of for height and width.
+        if method == "lower_bound":
+            height_bound = dict(min_val=self.__height)
+            width_bound = dict(min_val=self.__width)
+        elif method == "upper_bound":
+            height_bound = dict(max_val=self.__height)
+            width_bound = dict(max_val=self.__width)
+        elif method == "minimal":
+            height_bound = {}
+            width_bound = {}
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        new_height = self.constrain_to_multiple_of(scale_height * height, **height_bound)
+        new_width = self.constrain_to_multiple_of(scale_width * width, **width_bound)
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        """Resize ``sample["image"]`` and, if ``resize_target`` is set, the targets present.
+
+        Returns:
+            dict: the same ``sample`` object, updated in place.
+        """
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+        dsize = (width, height)
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"], dsize, interpolation=self.__image_interpolation_method
+        )
+
+        if self.__resize_target:
+            for key in ("disparity", "depth"):
+                if key in sample:
+                    sample[key] = cv2.resize(sample[key], dsize, interpolation=cv2.INTER_NEAREST)
+
+            # "mask" is optional like the other targets; without this check a
+            # sample that has no mask raises KeyError here.
+            if "mask" in sample:
+                mask = cv2.resize(sample["mask"].astype(np.float32), dsize, interpolation=cv2.INTER_NEAREST)
+                sample["mask"] = mask.astype(bool)
+
+        return sample
+
+
+class NormalizeImage(object):
+    """Normalize image by given mean and std."""
+
+    def __init__(self, mean, std):
+        """Keep the ``mean`` and ``std`` to apply."""
+        self.__mean, self.__std = mean, std
+
+    def __call__(self, sample):
+        """Replace ``sample["image"]`` by ``(image - mean) / std`` and return ``sample``."""
+        centered = sample["image"] - self.__mean
+        sample["image"] = centered / self.__std
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        """Make the image channel-first, contiguous ``float32``; cast the optional targets.
+
+        "mask", "disparity" and "depth" (when present) become contiguous ``float32`` too.
+        """
+        chw = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(chw).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = np.ascontiguousarray(sample["mask"].astype(np.float32))
+
+        # Same conversion for the remaining optional targets.
+        for key in ("disparity", "depth"):
+            if key in sample:
+                sample[key] = np.ascontiguousarray(sample[key].astype(np.float32))
+
+        return sample
diff --git a/sd-webui-controlnet/annotator/midas/midas/vit.py b/sd-webui-controlnet/annotator/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/sd-webui-controlnet/annotator/midas/utils.py b/sd-webui-controlnet/annotator/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/sd-webui-controlnet/annotator/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+ path (str): pathto file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+ out = np.zeros(depth.shape, dtype=depth.type)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/sd-webui-controlnet/annotator/mlsd/LICENSE b/sd-webui-controlnet/annotator/mlsd/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mlsd/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021-present NAVER Corp.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/mlsd/__init__.py b/sd-webui-controlnet/annotator/mlsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9791e3f78622f1e669df7e420ffd1cc7a0a4ec4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mlsd/__init__.py
@@ -0,0 +1,49 @@
+import cv2
+import numpy as np
+import torch
+import os
+
+from einops import rearrange
+from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
+from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
+from .utils import pred_lines
+from modules import devices
+from annotator.annotator_path import models_path
+
+# Cached MLSD network; None until apply_mlsd() builds it.
+mlsdmodel = None
+# URL the checkpoint is downloaded from when no local copy exists.
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
+# Directory of this file; apply_mlsd() prefers a checkpoint found here.
+old_modeldir = os.path.dirname(os.path.realpath(__file__))
+# "mlsd" directory under models_path; used as the download target.
+modeldir = os.path.join(models_path, "mlsd")
+
+def unload_mlsd_model():
+ global mlsdmodel
+ if mlsdmodel is not None:
+ mlsdmodel = mlsdmodel.cpu()
+
+def apply_mlsd(input_image, thr_v, thr_d):
+ global modelpath, mlsdmodel
+ if mlsdmodel is None:
+ modelpath = os.path.join(modeldir, "mlsd_large_512_fp32.pth")
+ old_modelpath = os.path.join(old_modeldir, "mlsd_large_512_fp32.pth")
+ if os.path.exists(old_modelpath):
+ modelpath = old_modelpath
+ elif not os.path.exists(modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir=modeldir)
+ mlsdmodel = MobileV2_MLSD_Large()
+ mlsdmodel.load_state_dict(torch.load(modelpath), strict=True)
+ mlsdmodel = mlsdmodel.to(devices.get_device_for("controlnet")).eval()
+
+ model = mlsdmodel
+ assert input_image.ndim == 3
+ img = input_image
+ img_output = np.zeros_like(img)
+ try:
+ with torch.no_grad():
+ lines = pred_lines(img, model, [img.shape[0], img.shape[1]], thr_v, thr_d)
+ for line in lines:
+ x_start, y_start, x_end, y_end = [int(val) for val in line]
+ cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
+ except Exception as e:
+ pass
+ return img_output[:, :, 0]
diff --git a/sd-webui-controlnet/annotator/mlsd/models/mbv2_mlsd_large.py b/sd-webui-controlnet/annotator/mlsd/models/mbv2_mlsd_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mlsd/models/mbv2_mlsd_large.py
@@ -0,0 +1,292 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+    """Fuse two feature maps.
+
+    Each input goes through its own 1x1 conv + BatchNorm + ReLU branch; the
+    second input is optionally upsampled 2x (bilinear) before both results are
+    concatenated along the channel axis.
+    """
+
+    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+        super().__init__()
+        # conv1 handles input ``b`` (in_c2 channels), conv2 handles input ``a``.
+        self.conv1 = self._pointwise(in_c2, out_c2)
+        self.conv2 = self._pointwise(in_c1, out_c1)
+        self.upscale = upscale
+
+    @staticmethod
+    def _pointwise(channels_in, channels_out):
+        """Build one 1x1 conv + BatchNorm + in-place ReLU branch."""
+        return nn.Sequential(
+            nn.Conv2d(channels_in, channels_out, kernel_size=1),
+            nn.BatchNorm2d(channels_out),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, a, b):
+        coarse = self.conv1(b)
+        fine = self.conv2(a)
+        if self.upscale:
+            coarse = F.interpolate(coarse, scale_factor=2.0, mode='bilinear', align_corners=True)
+        return torch.cat((fine, coarse), dim=1)
+
+
+class BlockTypeB(nn.Module):
+    """3x3 conv block with a residual add, followed by a 3x3 conv to ``out_c`` channels."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        residual_layers = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        projection_layers = [
+            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_c),
+            nn.ReLU(),
+        ]
+        self.conv1 = nn.Sequential(*residual_layers)
+        self.conv2 = nn.Sequential(*projection_layers)
+
+    def forward(self, x):
+        # Residual connection around the first conv block only.
+        summed = self.conv1(x) + x
+        return self.conv2(summed)
+
+class BlockTypeC(nn.Module):
+    """Output head: dilated 3x3 conv block, plain 3x3 conv block, then a 1x1 conv to ``out_c`` channels."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        # Dilation 5 with padding 5 keeps the spatial size unchanged.
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        )
+        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv3(self.conv2(self.conv1(x)))
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    Round ``v`` to the nearest multiple of ``divisor``.
+
+    Adapted from the TensorFlow MobileNet reference:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v: value to round (a channel count)
+    :param divisor: the result is a multiple of this
+    :param min_value: lower bound for the result; defaults to ``divisor``
+    :return: the rounded value, never more than 10% below ``v``
+    """
+    lower_bound = divisor if min_value is None else min_value
+    nearest = int(v + divisor / 2) // divisor * divisor
+    rounded = max(lower_bound, nearest)
+    # Rounding down must not lose more than 10% of the requested value.
+    if rounded < 0.9 * v:
+        rounded += divisor
+    return rounded
+
+
+class ConvBNReLU(nn.Sequential):
+    """Bias-free Conv2d + BatchNorm2d + ReLU6 with TFLite-style padding for stride 2."""
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        self.channel_pad = out_planes - in_planes
+        self.stride = stride
+
+        # TFLite pads stride-2 convolutions differently from PyTorch: the conv
+        # itself gets no padding and forward() pads the input on one side instead.
+        padding = 0 if stride == 2 else (kernel_size - 1) // 2
+
+        super().__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True)
+        )
+        # Registered as a child module but deliberately skipped in forward().
+        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        if self.stride == 2:
+            # One zero column on the right and one zero row at the bottom.
+            x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+
+        for layer in self:
+            if isinstance(layer, nn.MaxPool2d):
+                continue
+            x = layer(x)
+        return x
+
+
+class InvertedResidual(nn.Module):
+    """Inverted residual block: optional 1x1 expansion, depthwise 3x3, linear 1x1 projection."""
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super().__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        # The skip connection is only valid when the block keeps shape.
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        # pw (omitted when there is nothing to expand)
+        expansion = [ConvBNReLU(inp, hidden_dim, kernel_size=1)] if expand_ratio != 1 else []
+        self.conv = nn.Sequential(
+            *expansion,
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(oup),
+        )
+
+    def forward(self, x):
+        out = self.conv(x)
+        return x + out if self.use_res_connect else out
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, pretrained=True):
+        """
+        Truncated MobileNet V2 feature extractor.
+
+        The stem convolution takes 4 input channels, only the five stages
+        listed in ``inverted_residual_setting`` are built, and the width
+        multiplier (1.0) and channel rounding (multiples of 8) are fixed here.
+
+        Args:
+            pretrained (bool): if true, call ``_load_pretrained_model`` once the
+                default weight initialisation has run.
+        """
+        super(MobileNetV2, self).__init__()
+
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        width_mult = 1.0
+        round_nearest = 8
+
+        inverted_residual_setting = [
+            # t (expand ratio), c (output channels), n (repeats), s (stride of the first repeat)
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+            #[6, 160, 3, 2],
+            #[6, 320, 1, 1],
+        ]
+
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError("inverted_residual_setting should be non-empty "
+                             "or a 4-element list, got {}".format(inverted_residual_setting))
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features = [ConvBNReLU(4, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+
+        self.features = nn.Sequential(*features)
+        # Indices into self.features whose outputs are returned by forward().
+        self.fpn_selected = [1, 3, 6, 10, 13]
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+        if pretrained:
+            self._load_pretrained_model()
+
+    def _forward_impl(self, x):
+        """Run the feature stack and return the five outputs selected by ``fpn_selected``."""
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        fpn_features = []
+        for i, f in enumerate(self.features):
+            # Nothing after the last selected layer is needed.
+            if i > self.fpn_selected[-1]:
+                break
+            x = f(x)
+            if i in self.fpn_selected:
+                fpn_features.append(x)
+
+        c1, c2, c3, c4, c5 = fpn_features
+        return c1, c2, c3, c4, c5
+
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+    def _load_pretrained_model(self):
+        """Copy compatible tensors from the published MobileNetV2 checkpoint into this model."""
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+        state_dict = self.state_dict()
+        # Take only entries this truncated model has AND whose shape matches.
+        # Matching on the key alone would make load_state_dict reject the
+        # update on a shape mismatch (the stem conv here has 4 input channels).
+        model_dict = {
+            k: v for k, v in pretrain_dict.items()
+            if k in state_dict and v.shape == state_dict[k].shape
+        }
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Large(nn.Module):
+    """M-LSD "large" network: MobileNetV2 backbone plus a top-down decoder.
+
+    Each decoder step merges the running feature map with the next backbone
+    output (BlockTypeA) and refines it (BlockTypeB); BlockTypeC produces 16
+    channels, of which forward() returns channels 7 and up.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.backbone = MobileNetV2(pretrained=False)
+
+        # c4 + c5 (no upsampling for this pair)
+        self.block15 = BlockTypeA(in_c1=64, in_c2=96, out_c1=64, out_c2=64, upscale=False)
+        self.block16 = BlockTypeB(128, 64)
+
+        # c3 + previous
+        self.block17 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
+        self.block18 = BlockTypeB(128, 64)
+
+        # c2 + previous
+        self.block19 = BlockTypeA(in_c1=24, in_c2=64, out_c1=64, out_c2=64)
+        self.block20 = BlockTypeB(128, 64)
+
+        # c1 + previous, then the output head
+        self.block21 = BlockTypeA(in_c1=16, in_c2=64, out_c1=64, out_c2=64)
+        self.block22 = BlockTypeB(128, 64)
+
+        self.block23 = BlockTypeC(64, 16)
+
+    def forward(self, x):
+        c1, c2, c3, c4, c5 = self.backbone(x)
+
+        feat = self.block16(self.block15(c4, c5))
+        feat = self.block18(self.block17(c3, feat))
+        feat = self.block20(self.block19(c2, feat))
+        feat = self.block22(self.block21(c1, feat))
+        feat = self.block23(feat)
+
+        # Drop the first 7 of the 16 head channels.
+        return feat[:, 7:, :, :]
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/mlsd/models/mbv2_mlsd_tiny.py b/sd-webui-controlnet/annotator/mlsd/models/mbv2_mlsd_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mlsd/models/mbv2_mlsd_tiny.py
@@ -0,0 +1,275 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from torch.nn import functional as F
+
+
+class BlockTypeA(nn.Module):
+    """Fuse two feature maps.
+
+    Input ``b`` goes through ``conv1`` and input ``a`` through ``conv2`` (each
+    a 1x1 conv + BatchNorm + ReLU); ``b`` is upsampled 2x (bilinear) when
+    ``upscale`` is true, and both are concatenated along the channel axis.
+
+    Args:
+        in_c1, out_c1: input/output channels of the branch applied to ``a``.
+        in_c2, out_c2: input/output channels of the branch applied to ``b``.
+        upscale: whether ``b`` is upsampled before concatenation (default True).
+    """
+
+    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+        super(BlockTypeA, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c2, out_c2, kernel_size=1),
+            nn.BatchNorm2d(out_c2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c1, out_c1, kernel_size=1),
+            nn.BatchNorm2d(out_c1),
+            nn.ReLU(inplace=True)
+        )
+        self.upscale = upscale
+
+    def forward(self, a, b):
+        b = self.conv1(b)
+        a = self.conv2(a)
+        # Honour the constructor flag; it used to be stored but ignored here.
+        # The default (True) keeps the previous behaviour.
+        if self.upscale:
+            b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+        return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+    """Residual 3x3 conv block (``in_c`` -> ``in_c``) followed by a 3x3 conv block to ``out_c``."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+
+        def conv_bn_relu(channels_in, channels_out):
+            # 3x3 convolution that preserves the spatial size.
+            return nn.Sequential(
+                nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
+                nn.BatchNorm2d(channels_out),
+                nn.ReLU()
+            )
+
+        self.conv1 = conv_bn_relu(in_c, in_c)
+        self.conv2 = conv_bn_relu(in_c, out_c)
+
+    def forward(self, x):
+        # The skip connection wraps conv1 only.
+        return self.conv2(self.conv1(x) + x)
+
+class BlockTypeC(nn.Module):
+    """Head block: dilated 3x3 conv, plain 3x3 conv (both with BatchNorm + ReLU), then a 1x1 conv to ``out_c``."""
+
+    def __init__(self, in_c, out_c):
+        super().__init__()
+        dilated = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        plain = [
+            nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU(),
+        ]
+        self.conv1 = nn.Sequential(*dilated)
+        self.conv2 = nn.Sequential(*plain)
+        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+    def forward(self, x):
+        for stage in (self.conv1, self.conv2, self.conv3):
+            x = stage(x)
+        return x
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    Snap a channel count to a multiple of ``divisor``.
+
+    Taken from the original TensorFlow MobileNet code:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v: requested value
+    :param divisor: granularity of the result
+    :param min_value: smallest allowed result (``divisor`` when omitted)
+    :return: a multiple-of-``divisor`` value that is at least 90% of ``v``
+    """
+    if min_value is None:
+        min_value = divisor
+    half_step = divisor / 2
+    candidate = max(min_value, int(v + half_step) // divisor * divisor)
+    # Bump up one step if rounding down lost more than 10%.
+    return candidate + divisor if candidate < 0.9 * v else candidate
+
+
+class ConvBNReLU(nn.Sequential):
+    """Conv2d (no bias) -> BatchNorm2d -> ReLU6, padded the TFLite way when stride is 2."""
+
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        self.channel_pad = out_planes - in_planes
+        self.stride = stride
+
+        # For stride 2 the convolution is left unpadded; forward() adds the
+        # asymmetric padding instead. Otherwise use symmetric "same" padding.
+        symmetric_padding = (kernel_size - 1) // 2
+        padding = symmetric_padding if stride != 2 else 0
+
+        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False)
+        norm = nn.BatchNorm2d(out_planes)
+        activation = nn.ReLU6(inplace=True)
+        super().__init__(conv, norm, activation)
+        # Attached as a child module; forward() never applies it.
+        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        # TFLite-style padding: one extra zero column (right) and row (bottom).
+        if self.stride == 2:
+            x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+
+        active_layers = [layer for layer in self if not isinstance(layer, nn.MaxPool2d)]
+        for layer in active_layers:
+            x = layer(x)
+        return x
+
+
+class InvertedResidual(nn.Module):
+    """MobileNetV2 building block: pointwise expansion (skipped when the ratio is 1), depthwise conv, linear pointwise projection."""
+
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super().__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        # Identity shortcut only when input and output have the same shape.
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        stages = []
+        if expand_ratio != 1:
+            # pw
+            stages.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        # dw
+        stages.append(ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim))
+        # pw-linear
+        stages.append(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False))
+        stages.append(nn.BatchNorm2d(oup))
+        self.conv = nn.Sequential(*stages)
+
+    def forward(self, x):
+        if not self.use_res_connect:
+            return self.conv(x)
+        return x + self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, pretrained=True):
+        """
+        Truncated MobileNet V2 feature extractor.
+
+        The stem convolution takes 4 input channels, only the four stages
+        listed in ``inverted_residual_setting`` are built, and the width
+        multiplier (1.0) and channel rounding (multiples of 8) are fixed here.
+
+        Args:
+            pretrained (bool): accepted but currently ignored; the call to
+                ``_load_pretrained_model`` at the end of the constructor is
+                commented out.
+        """
+        super(MobileNetV2, self).__init__()
+
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        width_mult = 1.0
+        round_nearest = 8
+
+        inverted_residual_setting = [
+            # t (expand ratio), c (output channels), n (repeats), s (stride of the first repeat)
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            #[6, 96, 3, 1],
+            #[6, 160, 3, 2],
+            #[6, 320, 1, 1],
+        ]
+
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError("inverted_residual_setting should be non-empty "
+                             "or a 4-element list, got {}".format(inverted_residual_setting))
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features = [ConvBNReLU(4, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        self.features = nn.Sequential(*features)
+
+        # Indices into self.features whose outputs are returned by forward().
+        self.fpn_selected = [3, 6, 10]
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+
+        #if pretrained:
+        #    self._load_pretrained_model()
+
+    def _forward_impl(self, x):
+        """Run the feature stack and return the three outputs selected by ``fpn_selected``."""
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        fpn_features = []
+        for i, f in enumerate(self.features):
+            # Nothing after the last selected layer is needed.
+            if i > self.fpn_selected[-1]:
+                break
+            x = f(x)
+            if i in self.fpn_selected:
+                fpn_features.append(x)
+
+        c2, c3, c4 = fpn_features
+        return c2, c3, c4
+
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+    def _load_pretrained_model(self):
+        """Copy compatible tensors from the published MobileNetV2 checkpoint into this model.
+
+        Not called by the constructor at present; kept for explicit use.
+        """
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+        state_dict = self.state_dict()
+        # Take only entries this truncated model has AND whose shape matches.
+        # Matching on the key alone would make load_state_dict reject the
+        # update on a shape mismatch (the stem conv here has 4 input channels).
+        model_dict = {
+            k: v for k, v in pretrain_dict.items()
+            if k in state_dict and v.shape == state_dict[k].shape
+        }
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Tiny(nn.Module):
+    """M-LSD "tiny" network: MobileNetV2 backbone, two merge/refine steps and a head.
+
+    forward() keeps channels 7 and up of the 16-channel head output and
+    upsamples the result 2x (bilinear).
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.backbone = MobileNetV2(pretrained=True)
+
+        # c3 + c4
+        self.block12 = BlockTypeA(in_c1=32, in_c2=64, out_c1=64, out_c2=64)
+        self.block13 = BlockTypeB(128, 64)
+
+        # c2 + previous
+        self.block14 = BlockTypeA(in_c1=24, in_c2=64, out_c1=32, out_c2=32)
+        self.block15 = BlockTypeB(64, 64)
+
+        self.block16 = BlockTypeC(64, 16)
+
+    def forward(self, x):
+        c2, c3, c4 = self.backbone(x)
+
+        feat = self.block13(self.block12(c3, c4))
+        feat = self.block15(self.block14(c2, feat))
+        feat = self.block16(feat)
+        # Drop the first 7 head channels before upsampling.
+        feat = feat[:, 7:, :, :]
+        return F.interpolate(feat, scale_factor=2.0, mode='bilinear', align_corners=True)
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/mlsd/utils.py b/sd-webui-controlnet/annotator/mlsd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9cc5d904d9dd34d2ba4c902f3993f7abbb7ac5e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mlsd/utils.py
@@ -0,0 +1,581 @@
+'''
+modified by lihaoweicv
+pytorch version
+'''
+
+'''
+M-LSD
+Copyright 2021-present NAVER Corp.
+Apache License v2.0
+'''
+
+import os
+import numpy as np
+import cv2
+import torch
+from torch.nn import functional as F
+from modules import devices
+
+
+def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
+    '''
+    Pick the strongest local maxima of the centre heat map.
+
+    tpMap: tensor of shape (1, c, h, w); only batch size 1 is supported
+        center: tpMap[1, 0, :, :]
+        displacement: tpMap[1, 1:5, :, :]
+    topk_n: maximum number of points to return (capped at h * w)
+    ksize: window size of the max-pool used to keep local maxima only
+
+    Returns numpy arrays (ptss, scores, displacement):
+        ptss: (k, 2) integer (y, x) positions, highest score first
+        scores: (k,) sigmoid scores; suppressed positions score 0
+        displacement: (h, w, 4) displacement channels, channel-last
+    '''
+    b, c, h, w = tpMap.shape
+    assert b==1, 'only support bsize==1'
+    displacement = tpMap[:, 1:5, :, :][0]
+    center = tpMap[:, 0, :, :]
+    heat = torch.sigmoid(center)
+    # A position survives only if it equals the maximum of its ksize window.
+    hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
+    keep = (hmax == heat).float()
+    heat = heat * keep
+    heat = heat.reshape(-1, )
+
+    # torch.topk raises if asked for more elements than exist (small maps).
+    topk_n = min(topk_n, heat.numel())
+    scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
+    # Same floor-division form pred_squares uses; flat index -> (row, column).
+    yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+    xx = torch.fmod(indices, w).unsqueeze(-1)
+    ptss = torch.cat((yy, xx),dim=-1)
+
+    ptss = ptss.detach().cpu().numpy()
+    scores = scores.detach().cpu().numpy()
+    displacement = displacement.detach().cpu().numpy()
+    displacement = displacement.transpose((1,2,0))
+    return ptss, scores, displacement
+
+
+def pred_lines(image, model,
+               input_shape=[512, 512],
+               score_thr=0.10,
+               dist_thr=20.0):
+    """Run the M-LSD model on ``image`` and return the detected line segments.
+
+    Args:
+        image: array of shape (h, w, channels). A constant channel of ones is
+            appended after resizing, so a 3-channel image gives the 4-channel
+            input the backbone's first convolution is built for.
+        model: callable applied to the (1, C, H, W) float tensor; its output is
+            decoded by ``deccode_output_score_and_ptss``.
+        input_shape: [height, width] the image is resized to before inference.
+        score_thr: minimum centre score for a candidate to be kept.
+        dist_thr: minimum start-to-end distance for a candidate to be kept.
+
+    Returns:
+        Array of shape (n, 4) with rows (x_start, y_start, x_end, y_end) in the
+        coordinates of the original image; shape (0, 4) when nothing passes the
+        thresholds.
+    """
+    h, w, _ = image.shape
+    h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
+
+    resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+
+    resized_image = resized_image.transpose((2,0,1))
+    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+    batch_image = (batch_image / 127.5) - 1.0
+
+    batch_image = torch.from_numpy(batch_image).float().to(devices.get_device_for("controlnet"))
+    outputs = model(batch_image)
+    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+    start = vmap[:, :, :2]
+    end = vmap[:, :, 2:]
+    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+    segments_list = []
+    for center, score in zip(pts, pts_score):
+        y, x = center
+        distance = dist_map[y, x]
+        if score > score_thr and distance > dist_thr:
+            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+            x_start = x + disp_x_start
+            y_start = y + disp_y_start
+            x_end = x + disp_x_end
+            y_end = y + disp_y_end
+            segments_list.append([x_start, y_start, x_end, y_end])
+
+    if not segments_list:
+        # np.array([]) is 1-D, so the column indexing below would raise
+        # IndexError; return an empty (0, 4) result instead.
+        return np.zeros((0, 4))
+
+    # The factor 2 maps output-map coordinates to the resized input
+    # (e.g. 256 > 512); the ratios then map to the original image size.
+    lines = 2 * np.array(segments_list)
+    lines[:, 0] = lines[:, 0] * w_ratio
+    lines[:, 1] = lines[:, 1] * h_ratio
+    lines[:, 2] = lines[:, 2] * w_ratio
+    lines[:, 3] = lines[:, 3] * h_ratio
+
+    return lines
+
+
+def pred_squares(image,
+ model,
+ input_shape=[512, 512],
+ params={'score': 0.06,
+ 'outside_ratio': 0.28,
+ 'inside_ratio': 0.45,
+ 'w_overlap': 0.0,
+ 'w_degree': 1.95,
+ 'w_length': 0.0,
+ 'w_area': 1.86,
+ 'w_center': 0.14}):
+ '''
+ shape = [height, width]
+ '''
+ h, w, _ = image.shape
+ original_shape = [h, w]
+
+ resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
+ np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+ resized_image = resized_image.transpose((2, 0, 1))
+ batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+ batch_image = (batch_image / 127.5) - 1.0
+
+ batch_image = torch.from_numpy(batch_image).float().to(devices.get_device_for("controlnet"))
+ outputs = model(batch_image)
+
+ pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+ start = vmap[:, :, :2] # (x, y)
+ end = vmap[:, :, 2:] # (x, y)
+ dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+ junc_list = []
+ segments_list = []
+ for junc, score in zip(pts, pts_score):
+ y, x = junc
+ distance = dist_map[y, x]
+ if score > params['score'] and distance > 20.0:
+ junc_list.append([x, y])
+ disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+ d_arrow = 1.0
+ x_start = x + d_arrow * disp_x_start
+ y_start = y + d_arrow * disp_y_start
+ x_end = x + d_arrow * disp_x_end
+ y_end = y + d_arrow * disp_y_end
+ segments_list.append([x_start, y_start, x_end, y_end])
+
+ segments = np.array(segments_list)
+
+ ####### post processing for squares
+ # 1. get unique lines
+ point = np.array([[0, 0]])
+ point = point[0]
+ start = segments[:, :2]
+ end = segments[:, 2:]
+ diff = start - end
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+
+ d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+ theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+ theta[theta < 0.0] += 180
+ hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+
+ d_quant = 1
+ theta_quant = 2
+ hough[:, 0] //= d_quant
+ hough[:, 1] //= theta_quant
+ _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+
+ acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+ idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+ yx_indices = hough[indices, :].astype('int32')
+ acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+ idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+
+ acc_map_np = acc_map
+ # acc_map = acc_map[None, :, :, None]
+ #
+ # ### fast suppression using tensorflow op
+ # acc_map = tf.constant(acc_map, dtype=tf.float32)
+ # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+ # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+ # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+ # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+ # _, h, w, _ = acc_map.shape
+ # y = tf.expand_dims(topk_indices // w, axis=-1)
+ # x = tf.expand_dims(topk_indices % w, axis=-1)
+ # yx = tf.concat([y, x], axis=-1)
+
+ ### fast suppression using pytorch op
+ acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+ _,_, h, w = acc_map.shape
+ max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
+ acc_map = acc_map * ( (acc_map == max_acc_map).float() )
+ flatten_acc_map = acc_map.reshape([-1, ])
+
+ scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+ yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+ xx = torch.fmod(indices, w).unsqueeze(-1)
+ yx = torch.cat((yy, xx), dim=-1)
+
+ yx = yx.detach().cpu().numpy()
+
+ topk_values = scores.detach().cpu().numpy()
+ indices = idx_map[yx[:, 0], yx[:, 1]]
+ basis = 5 // 2
+
+ merged_segments = []
+ for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+ y, x = yx_pt
+ if max_indice == -1 or value == 0:
+ continue
+ segment_list = []
+ for y_offset in range(-basis, basis + 1):
+ for x_offset in range(-basis, basis + 1):
+ indice = idx_map[y + y_offset, x + x_offset]
+ cnt = int(acc_map_np[y + y_offset, x + x_offset])
+ if indice != -1:
+ segment_list.append(segments[indice])
+ if cnt > 1:
+ check_cnt = 1
+ current_hough = hough[indice]
+ for new_indice, new_hough in enumerate(hough):
+ if (current_hough == new_hough).all() and indice != new_indice:
+ segment_list.append(segments[new_indice])
+ check_cnt += 1
+ if check_cnt == cnt:
+ break
+ group_segments = np.array(segment_list).reshape([-1, 2])
+ sorted_group_segments = np.sort(group_segments, axis=0)
+ x_min, y_min = sorted_group_segments[0, :]
+ x_max, y_max = sorted_group_segments[-1, :]
+
+ deg = theta[max_indice]
+ if deg >= 90:
+ merged_segments.append([x_min, y_max, x_max, y_min])
+ else:
+ merged_segments.append([x_min, y_min, x_max, y_max])
+
+ # 2. get intersections
+ new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
+ start = new_segments[:, :2] # (x1, y1)
+ end = new_segments[:, 2:] # (x2, y2)
+ new_centers = (start + end) / 2.0
+ diff = start - end
+ dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+
+ # ax + by = c
+ a = diff[:, 1]
+ b = -diff[:, 0]
+ c = a * start[:, 0] + b * start[:, 1]
+ pre_det = a[:, None] * b[None, :]
+ det = pre_det - np.transpose(pre_det)
+
+ pre_inter_y = a[:, None] * c[None, :]
+ inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+ pre_inter_x = c[:, None] * b[None, :]
+ inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+ inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+
+ # 3. get corner information
+ # 3.1 get distance
+ '''
+ dist_segments:
+ | dist(0), dist(1), dist(2), ...|
+ dist_inter_to_segment1:
+ | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+ | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+ ...
+ dist_inter_to_semgnet2:
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+ ...
+ '''
+
+ dist_inter_to_segment1_start = np.sqrt(
+ np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment1_end = np.sqrt(
+ np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_start = np.sqrt(
+ np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+ dist_inter_to_segment2_end = np.sqrt(
+ np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
+
+ # sort ascending
+ dist_inter_to_segment1 = np.sort(
+ np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+ dist_inter_to_segment2 = np.sort(
+ np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+ axis=-1) # [n_batch, n_batch, 2]
+
+ # 3.2 get degree
+ inter_to_start = new_centers[:, None, :] - inter_pts
+ deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+ deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+ inter_to_end = new_centers[None, :, :] - inter_pts
+ deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+ deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+
+ '''
+ B -- G
+ | |
+ C -- R
+ B : blue / G: green / C: cyan / R: red
+
+ 0 -- 1
+ | |
+ 3 -- 2
+ '''
+ # rename variables
+ deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+ # sort deg ascending
+ deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+
+ deg_diff_map = np.abs(deg1_map - deg2_map)
+ # we only consider the smallest degree of intersect
+ deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+
+ # define available degree range
+ deg_range = [60, 120]
+
+ corner_dict = {corner_info: [] for corner_info in range(4)}
+ inter_points = []
+ for i in range(inter_pts.shape[0]):
+ for j in range(i + 1, inter_pts.shape[1]):
+ # i, j > line index, always i < j
+ x, y = inter_pts[i, j, :]
+ deg1, deg2 = deg_sort[i, j, :]
+ deg_diff = deg_diff_map[i, j]
+
+ check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+
+ outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
+ inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
+ check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+ (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+ dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+ ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+ (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+ dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+
+ if check_degree and check_distance:
+ corner_info = None
+
+ if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+ (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+ corner_info, color_info = 0, 'blue'
+ elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+ corner_info, color_info = 1, 'green'
+ elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+ corner_info, color_info = 2, 'black'
+ elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+ (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+ corner_info, color_info = 3, 'cyan'
+ else:
+ corner_info, color_info = 4, 'red' # we don't use it
+ continue
+
+ corner_dict[corner_info].append([x, y, i, j])
+ inter_points.append([x, y])
+
+ square_list = []
+ connect_list = []
+ segments_list = []
+ for corner0 in corner_dict[0]:
+ for corner1 in corner_dict[1]:
+ connect01 = False
+ for corner0_line in corner0[2:]:
+ if corner0_line in corner1[2:]:
+ connect01 = True
+ break
+ if connect01:
+ for corner2 in corner_dict[2]:
+ connect12 = False
+ for corner1_line in corner1[2:]:
+ if corner1_line in corner2[2:]:
+ connect12 = True
+ break
+ if connect12:
+ for corner3 in corner_dict[3]:
+ connect23 = False
+ for corner2_line in corner2[2:]:
+ if corner2_line in corner3[2:]:
+ connect23 = True
+ break
+ if connect23:
+ for corner3_line in corner3[2:]:
+ if corner3_line in corner0[2:]:
+ # SQUARE!!!
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+ square_list:
+ order: 0 > 1 > 2 > 3
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ | x0, y0, x1, y1, x2, y2, x3, y3 |
+ ...
+ connect_list:
+ order: 01 > 12 > 23 > 30
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ | line_idx01, line_idx12, line_idx23, line_idx30 |
+ ...
+ segments_list:
+ order: 0 > 1 > 2 > 3
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+ ...
+ '''
+ square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+ connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+ segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+
+ def check_outside_inside(segments_info, connect_idx):
+ # return 'outside or inside', min distance, cover_param, peri_param
+ if connect_idx == segments_info[0]:
+ check_dist_mat = dist_inter_to_segment1
+ else:
+ check_dist_mat = dist_inter_to_segment2
+
+ i, j = segments_info
+ min_dist, max_dist = check_dist_mat[i, j, :]
+ connect_dist = dist_segments[connect_idx]
+ if max_dist > connect_dist:
+ return 'outside', min_dist, 0, 1
+ else:
+ return 'inside', min_dist, -1, -1
+
+ top_square = None
+
+ try:
+ map_size = input_shape[0] / 2
+ squares = np.array(square_list).reshape([-1, 4, 2])
+ score_array = []
+ connect_array = np.array(connect_list)
+ segments_array = np.array(segments_list).reshape([-1, 4, 2])
+
+ # get degree of corners:
+ squares_rollup = np.roll(squares, 1, axis=1)
+ squares_rolldown = np.roll(squares, -1, axis=1)
+ vec1 = squares_rollup - squares
+ normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+ vec2 = squares_rolldown - squares
+ normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+ inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
+ squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
+
+ # get square score
+ overlap_scores = []
+ degree_scores = []
+ length_scores = []
+
+ for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+ '''
+ 0 -- 1
+ | |
+ 3 -- 2
+
+ # segments: [4, 2]
+ # connects: [4]
+ '''
+
+ ###################################### OVERLAP SCORES
+ cover = 0
+ perimeter = 0
+ # check 0 > 1 > 2 > 3
+ square_length = []
+
+ for start_idx in range(4):
+ end_idx = (start_idx + 1) % 4
+
+ connect_idx = connects[start_idx] # segment idx of segment01
+ start_segments = segments[start_idx]
+ end_segments = segments[end_idx]
+
+ start_point = square[start_idx]
+ end_point = square[end_idx]
+
+ # check whether outside or inside
+ start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+ connect_idx)
+ end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+
+ cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+ perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+
+ square_length.append(
+ dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+
+ overlap_scores.append(cover / perimeter)
+ ######################################
+ ###################################### DEGREE SCORES
+ '''
+ deg0 vs deg2
+ deg1 vs deg3
+ '''
+ deg0, deg1, deg2, deg3 = degree
+ deg_ratio1 = deg0 / deg2
+ if deg_ratio1 > 1.0:
+ deg_ratio1 = 1 / deg_ratio1
+ deg_ratio2 = deg1 / deg3
+ if deg_ratio2 > 1.0:
+ deg_ratio2 = 1 / deg_ratio2
+ degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+ ######################################
+ ###################################### LENGTH SCORES
+ '''
+ len0 vs len2
+ len1 vs len3
+ '''
+ len0, len1, len2, len3 = square_length
+ len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+ len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+ length_scores.append((len_ratio1 + len_ratio2) / 2)
+
+ ######################################
+
+ overlap_scores = np.array(overlap_scores)
+ overlap_scores /= np.max(overlap_scores)
+
+ degree_scores = np.array(degree_scores)
+ # degree_scores /= np.max(degree_scores)
+
+ length_scores = np.array(length_scores)
+
+ ###################################### AREA SCORES
+ area_scores = np.reshape(squares, [-1, 4, 2])
+ area_x = area_scores[:, :, 0]
+ area_y = area_scores[:, :, 1]
+ correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+ area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+ area_scores = 0.5 * np.abs(area_scores + correction)
+ area_scores /= (map_size * map_size) # np.max(area_scores)
+ ######################################
+
+ ###################################### CENTER SCORES
+ centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
+ # squares: [n, 4, 2]
+ square_centers = np.mean(squares, axis=1) # [n, 2]
+ center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+ center_scores = center2center / (map_size / np.sqrt(2.0))
+
+ '''
+ score_w = [overlap, degree, area, center, length]
+ '''
+ score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+ score_array = params['w_overlap'] * overlap_scores \
+ + params['w_degree'] * degree_scores \
+ + params['w_area'] * area_scores \
+ - params['w_center'] * center_scores \
+ + params['w_length'] * length_scores
+
+ best_square = []
+
+ sorted_idx = np.argsort(score_array)[::-1]
+ score_array = score_array[sorted_idx]
+ squares = squares[sorted_idx]
+
+ except Exception as e:
+ pass
+
+ '''return list
+ merged_lines, squares, scores
+ '''
+
+ try:
+ new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+ new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+ new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+ except:
+ new_segments = []
+
+ try:
+ squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+ squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ squares = []
+ score_array = []
+
+ try:
+ inter_points = np.array(inter_points)
+ inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+ inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+ except:
+ inter_points = []
+
+ return new_segments, squares, score_array, inter_points
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+from .video import *
+from .visualization import *
+
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - runner
+# - parallel
+# - op
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/arraymisc/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/arraymisc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/arraymisc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantization import dequantize, quantize
+
+__all__ = ['quantize', 'dequantize']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+    """Quantize an array of (-inf, inf) to [0, levels-1].
+
+    Values are first clipped to ``[min_val, max_val]``; the clipped range is
+    then split into ``levels`` equal-width bins and every value is replaced
+    by the index of its bin.
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels. Must be an integer (Python or
+            NumPy) greater than 1.
+        dtype (np.type): The type of the quantized array.
+
+    Returns:
+        ndarray: Quantized array with values in ``[0, levels - 1]``.
+
+    Raises:
+        ValueError: If ``levels`` is not an integer greater than 1, or if
+            ``min_val`` is not smaller than ``max_val``.
+    """
+    if not (isinstance(levels, (int, np.integer)) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    arr = np.clip(arr, min_val, max_val) - min_val
+    # A value equal to max_val would land in bin ``levels``; np.minimum folds
+    # it back into the last valid bin.
+    quantized_arr = np.minimum(
+        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+    return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+    """Dequantize an array.
+
+    Each bin index is mapped to the centre of its bin, where the bins are the
+    ``levels`` equal-width intervals that partition ``[min_val, max_val]``.
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels. Must be an integer (Python or
+            NumPy) greater than 1.
+        dtype (np.type): The type of the dequantized array.
+
+    Returns:
+        ndarray: Dequantized array.
+
+    Raises:
+        ValueError: If ``levels`` is not an integer greater than 1, or if
+            ``min_val`` is not smaller than ``max_val``.
+    """
+    if not (isinstance(levels, (int, np.integer)) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    # The +0.5 offset selects the midpoint of the bin rather than its lower
+    # edge.
+    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+                                                   min_val) / levels + min_val
+
+    return dequantized_arr
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+# yapf: disable
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+ ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+ ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+ DepthwiseSeparableConvModule, GeneralizedAttention,
+ HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+ NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+ build_activation_layer, build_conv_layer,
+ build_norm_layer, build_padding_layer, build_plugin_layer,
+ build_upsample_layer, conv_ws_2d, is_norm)
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
+ constant_init, fuse_conv_bn, get_model_complexity_info,
+ initialize, kaiming_init, normal_init, trunc_normal_init,
+ uniform_init, xavier_init)
+from .vgg import VGG, make_vgg_layer
+
+__all__ = [
+ 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+ 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+ 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+ 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+ 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+ 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+ 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+ 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+ 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+ 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+ 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+
+class AlexNet(nn.Module):
+    """AlexNet backbone.
+
+    The convolutional feature extractor is always built. The fully connected
+    classifier head is only created (and applied in ``forward``) when
+    ``num_classes`` is positive.
+
+    Args:
+        num_classes (int): number of classes for classification.
+    """
+
+    def __init__(self, num_classes=-1):
+        super().__init__()
+        self.num_classes = num_classes
+
+        conv_stack = [
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        ]
+        self.features = nn.Sequential(*conv_stack)
+
+        if self.num_classes > 0:
+            head = [
+                nn.Dropout(),
+                nn.Linear(256 * 6 * 6, 4096),
+                nn.ReLU(inplace=True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(inplace=True),
+                nn.Linear(4096, num_classes),
+            ]
+            self.classifier = nn.Sequential(*head)
+
+    def init_weights(self, pretrained=None):
+        """Load a checkpoint, or keep the default initialization.
+
+        Args:
+            pretrained (str | None): Checkpoint location handed to
+                ``load_checkpoint``. ``None`` leaves the weights untouched.
+
+        Raises:
+            TypeError: If ``pretrained`` is neither a str nor None.
+        """
+        if pretrained is None:
+            # use default initializer
+            return
+        if not isinstance(pretrained, str):
+            raise TypeError('pretrained must be a str or None')
+
+        root_logger = logging.getLogger()
+        from ..runner import load_checkpoint
+        load_checkpoint(self, pretrained, strict=False, logger=root_logger)
+
+    def forward(self, x):
+        """Return the feature map, or class scores when a head exists."""
+        feats = self.features(x)
+        if self.num_classes <= 0:
+            return feats
+
+        # The head expects the 256 x 6 x 6 feature map flattened per sample.
+        flat = feats.view(feats.size(0), 256 * 6 * 6)
+        return self.classifier(flat)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+from .scale import Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+ Linear, MaxPool2d, MaxPool3d)
+
+__all__ = [
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+ 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+ 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
+ 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+ 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+ 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+ 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8951058c8e77eda02c130f3401c9680702e231c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+
+# Register the activation classes that ship with torch.nn in the
+# ACTIVATION_LAYERS registry. No explicit name is passed to register_module.
+for module in [
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+ nn.Sigmoid, nn.Tanh
+]:
+ ACTIVATION_LAYERS.register_module(module=module)
+
+
+@ACTIVATION_LAYERS.register_module(name='Clip')
+@ACTIVATION_LAYERS.register_module()
+class Clamp(nn.Module):
+    """Clamp activation layer.
+
+    Limits every element of the feature map to the closed interval
+    :math:`[min, max]` by delegating to ``torch.clamp()``. The class is
+    registered twice: once without an explicit name and once as ``'Clip'``.
+
+    Args:
+        min (Number | optional): Lower-bound of the range to be clamped to.
+            Default to -1.
+        max (Number | optional): Upper-bound of the range to be clamped to.
+            Default to 1.
+    """
+
+    def __init__(self, min=-1., max=1.):
+        super().__init__()
+        self.min, self.max = min, max
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Clamped tensor.
+        """
+        clamped = torch.clamp(x, min=self.min, max=self.max)
+        return clamped
+
+
+class GELU(nn.Module):
+    r"""Gaussian Error Linear Unit activation.
+
+    Computes
+
+    .. math::
+        \text{GELU}(x) = x * \Phi(x)
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for
+    Gaussian Distribution. The computation is delegated to ``F.gelu``.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    .. image:: scripts/activation_images/GELU.png
+
+    Examples::
+
+        >>> m = nn.GELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input):
+        activated = F.gelu(input)
+        return activated
+
+
+# Pick which GELU implementation goes into the registry: the GELU class
+# defined in this module when running on 'parrots' or on a torch version
+# below 1.4, otherwise torch's own nn.GELU.
+if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
+ ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+ ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg):
+    """Build activation layer.
+
+    The config is resolved against the ``ACTIVATION_LAYERS`` registry via
+    ``build_from_cfg``.
+
+    Args:
+        cfg (dict): The activation layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate an activation layer.
+
+    Returns:
+        nn.Module: Created activation layer.
+    """
+    activation = build_from_cfg(cfg, ACTIVATION_LAYERS)
+    return activation
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+
+from ..utils import constant_init, kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+def last_zero_init(m):
+    """Constant-initialize ``m`` with 0.
+
+    For an ``nn.Sequential`` only its last sub-module is initialized; any
+    other module is initialized itself.
+    """
+    target = m[-1] if isinstance(m, nn.Sequential) else m
+    constant_init(target, val=0)
+
+
+@PLUGIN_LAYERS.register_module()
+class ContextBlock(nn.Module):
+    """ContextBlock module in GCNet.
+
+    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+    (https://arxiv.org/abs/1904.11492) for details.
+
+    The block pools the input into a per-channel context vector, passes that
+    vector through one bottleneck transform per fusion type, and fuses the
+    result back into the input by multiplication and/or addition.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        ratio (float): Ratio of channels of transform bottleneck
+        pooling_type (str): Pooling method for context modeling.
+            Options are 'att' and 'avg', stand for attention pooling and
+            average pooling respectively. Default: 'att'.
+        fusion_types (Sequence[str]): Fusion method for feature fusion,
+            Options are 'channel_add', 'channel_mul', stand for channelwise
+            addition and multiplication respectively. Default: ('channel_add',)
+    """
+
+    _abbr_ = 'context_block'
+
+    def __init__(self,
+                 in_channels,
+                 ratio,
+                 pooling_type='att',
+                 fusion_types=('channel_add', )):
+        super().__init__()
+        assert pooling_type in ['avg', 'att']
+        assert isinstance(fusion_types, (list, tuple))
+        valid_fusion_types = ['channel_add', 'channel_mul']
+        assert all(f in valid_fusion_types for f in fusion_types)
+        assert len(fusion_types) > 0, 'at least one fusion should be used'
+
+        self.in_channels = in_channels
+        self.ratio = ratio
+        self.planes = int(in_channels * ratio)
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+
+        if pooling_type == 'att':
+            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+            self.softmax = nn.Softmax(dim=2)
+        else:
+            self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+        # One bottleneck transform per requested fusion; None marks a fusion
+        # that is switched off.
+        self.channel_add_conv = (
+            self._bottleneck() if 'channel_add' in fusion_types else None)
+        self.channel_mul_conv = (
+            self._bottleneck() if 'channel_mul' in fusion_types else None)
+        self.reset_parameters()
+
+    def _bottleneck(self):
+        """Build the conv -> LayerNorm -> ReLU -> conv transform."""
+        return nn.Sequential(
+            nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+            nn.LayerNorm([self.planes, 1, 1]),
+            nn.ReLU(inplace=True),  # yapf: disable
+            nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+
+    def reset_parameters(self):
+        """Initialize the attention mask conv and zero the transform outputs."""
+        if self.pooling_type == 'att':
+            kaiming_init(self.conv_mask, mode='fan_in')
+            self.conv_mask.inited = True
+
+        for transform in (self.channel_add_conv, self.channel_mul_conv):
+            if transform is not None:
+                last_zero_init(transform)
+
+    def spatial_pool(self, x):
+        """Pool ``x`` ([N, C, H, W]) into a context tensor of [N, C, 1, 1]."""
+        batch, channel, height, width = x.size()
+        if self.pooling_type != 'att':
+            # [N, C, 1, 1]
+            return self.avg_pool(x)
+
+        # [N, C, H * W] -> [N, 1, C, H * W]
+        flat_x = x.view(batch, channel, height * width).unsqueeze(1)
+        # [N, 1, H, W] -> [N, 1, H * W]
+        mask = self.conv_mask(x).view(batch, 1, height * width)
+        # softmax over the spatial positions, then [N, 1, H * W, 1]
+        mask = self.softmax(mask).unsqueeze(-1)
+        # [N, 1, C, 1]
+        pooled = torch.matmul(flat_x, mask)
+        # [N, C, 1, 1]
+        return pooled.view(batch, channel, 1, 1)
+
+    def forward(self, x):
+        """Fuse the pooled context back into ``x``."""
+        # [N, C, 1, 1]
+        context = self.spatial_pool(x)
+
+        result = x
+        if self.channel_mul_conv is not None:
+            # [N, C, 1, 1]
+            gate = torch.sigmoid(self.channel_mul_conv(context))
+            result = result * gate
+        if self.channel_add_conv is not None:
+            # [N, C, 1, 1]
+            offset = self.channel_add_conv(context)
+            result = result + offset
+
+        return result
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
+from .registry import CONV_LAYERS
+
+# Register torch's convolution classes in the CONV_LAYERS registry.
+# 'Conv' is registered with nn.Conv2d, the same class as 'Conv2d'.
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+    """Build convolution layer.
+
+    Args:
+        cfg (None or dict): The conv layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate an conv layer.
+            ``None`` is treated as ``dict(type='Conv2d')``.
+        args (argument list): Arguments passed to the `__init__`
+            method of the corresponding conv layer.
+        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+            method of the corresponding conv layer.
+
+    Returns:
+        nn.Module: Created conv layer.
+
+    Raises:
+        TypeError: If ``cfg`` is neither None nor a dict.
+        KeyError: If ``cfg`` has no ``type`` key, or the type is not
+            registered in ``CONV_LAYERS``.
+    """
+    if cfg is None:
+        cfg_ = dict(type='Conv2d')
+    else:
+        if not isinstance(cfg, dict):
+            raise TypeError('cfg must be a dict')
+        if 'type' not in cfg:
+            raise KeyError('the cfg dict must contain the key "type"')
+        # Work on a copy so that popping 'type' does not alter the caller's
+        # config.
+        cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in CONV_LAYERS:
+        # This builder handles conv layers, so report a conv type (the
+        # previous message said "norm").
+        raise KeyError(f'Unrecognized conv type {layer_type}')
+    conv_layer = CONV_LAYERS.get(layer_type)
+
+    # The remaining cfg entries are forwarded as extra keyword arguments.
+    layer = conv_layer(*args, **kwargs, **cfg_)
+
+    return layer
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .registry import CONV_LAYERS
+
+
+@CONV_LAYERS.register_module()
+class Conv2dAdaptivePadding(nn.Conv2d):
+    """Implementation of 2D convolution in tensorflow with `padding` as "same",
+    which applies padding to input (if needed) so that input image gets fully
+    covered by filter and stride you specified. For stride 1, this will ensure
+    that output image size is same as input. For stride of 2, output dimensions
+    will be half, for example.
+
+    Note that the ``padding`` argument is accepted but not forwarded: the
+    parent ``nn.Conv2d`` is always constructed with a padding of 0 and the
+    required padding is computed from the input size in ``forward``.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+                         dilation, groups, bias)
+
+    def forward(self, x):
+        in_h, in_w = x.size()[-2:]
+        k_h, k_w = self.weight.size()[-2:]
+        s_h, s_w = self.stride
+        d_h, d_w = self.dilation
+
+        # Target output size: the input size divided by the stride, rounded up.
+        out_h = math.ceil(in_h / s_h)
+        out_w = math.ceil(in_w / s_w)
+
+        # Total padding needed per axis so the strided, dilated kernel covers
+        # the input; never negative.
+        total_pad_h = max((out_h - 1) * s_h + (k_h - 1) * d_h + 1 - in_h, 0)
+        total_pad_w = max((out_w - 1) * s_w + (k_w - 1) * d_w + 1 - in_w, 0)
+
+        if total_pad_h > 0 or total_pad_w > 0:
+            left = total_pad_w // 2
+            top = total_pad_h // 2
+            # When the total is odd, the extra pixel goes to the right/bottom.
+            x = F.pad(x, [left, total_pad_w - left, top, total_pad_h - top])
+        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+                        self.dilation, self.groups)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..43cab72624ccc04b2f7877383588a4bbacf9117a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class ConvModule(nn.Module):
+    """A conv block that bundles conv/norm/activation layers.
+
+    This block simplifies the usage of convolution layers, which are commonly
+    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+    It is based upon three build methods: `build_conv_layer()`,
+    `build_norm_layer()` and `build_activation_layer()`.
+
+    Besides, we add some additional features in this module.
+    1. Automatically set `bias` of the conv layer.
+    2. Spectral norm is supported.
+    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+    supports zero and circular padding, and we add "reflect" padding mode.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+            Same as that in ``nn._ConvNd``.
+        out_channels (int): Number of channels produced by the convolution.
+            Same as that in ``nn._ConvNd``.
+        kernel_size (int | tuple[int]): Size of the convolving kernel.
+            Same as that in ``nn._ConvNd``.
+        stride (int | tuple[int]): Stride of the convolution.
+            Same as that in ``nn._ConvNd``.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Same as that in ``nn._ConvNd``.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+            Same as that in ``nn._ConvNd``.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Same as that in ``nn._ConvNd``.
+        bias (bool | str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+            False. Default: "auto".
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        inplace (bool): Whether to use inplace mode for activation.
+            Ignored for activation types that take no ``inplace`` argument.
+            Default: True.
+        with_spectral_norm (bool): Whether use spectral norm in conv module.
+            Default: False.
+        padding_mode (str): If the `padding_mode` has not been supported by
+            current `Conv2d` in PyTorch, we will use our own padding layer
+            instead. Currently, we support ['zeros', 'circular'] with official
+            implementation and ['reflect'] with our own implementation.
+            Default: 'zeros'.
+        order (tuple[str]): The order of conv/norm/activation layers. It is a
+            sequence of "conv", "norm" and "act". Common examples are
+            ("conv", "norm", "act") and ("act", "conv", "norm").
+            Default: ('conv', 'norm', 'act').
+    """
+
+    _abbr_ = 'conv_block'
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias='auto',
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 inplace=True,
+                 with_spectral_norm=False,
+                 padding_mode='zeros',
+                 order=('conv', 'norm', 'act')):
+        super(ConvModule, self).__init__()
+        assert conv_cfg is None or isinstance(conv_cfg, dict)
+        assert norm_cfg is None or isinstance(norm_cfg, dict)
+        assert act_cfg is None or isinstance(act_cfg, dict)
+        official_padding_mode = ['zeros', 'circular']
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.inplace = inplace
+        self.with_spectral_norm = with_spectral_norm
+        self.with_explicit_padding = padding_mode not in official_padding_mode
+        self.order = order
+        assert isinstance(self.order, tuple) and len(self.order) == 3
+        assert set(order) == set(['conv', 'norm', 'act'])
+
+        self.with_norm = norm_cfg is not None
+        self.with_activation = act_cfg is not None
+        # if the conv layer is before a norm layer, bias is unnecessary.
+        if bias == 'auto':
+            bias = not self.with_norm
+        self.with_bias = bias
+
+        if self.with_explicit_padding:
+            pad_cfg = dict(type=padding_mode)
+            self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+        # reset padding to 0 for conv module
+        conv_padding = 0 if self.with_explicit_padding else padding
+        # build convolution layer
+        self.conv = build_conv_layer(
+            conv_cfg,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=conv_padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        # export the attributes of self.conv to a higher level for convenience
+        self.in_channels = self.conv.in_channels
+        self.out_channels = self.conv.out_channels
+        self.kernel_size = self.conv.kernel_size
+        self.stride = self.conv.stride
+        # keep the requested padding, even when an explicit padding layer
+        # applies it and the conv itself uses 0
+        self.padding = padding
+        self.dilation = self.conv.dilation
+        self.transposed = self.conv.transposed
+        self.output_padding = self.conv.output_padding
+        self.groups = self.conv.groups
+
+        if self.with_spectral_norm:
+            self.conv = nn.utils.spectral_norm(self.conv)
+
+        # build normalization layers
+        if self.with_norm:
+            # norm layer is after conv layer
+            if order.index('norm') > order.index('conv'):
+                norm_channels = out_channels
+            else:
+                norm_channels = in_channels
+            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+            self.add_module(self.norm_name, norm)
+            if self.with_bias:
+                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+                    warnings.warn(
+                        'Unnecessary conv bias before batch/instance norm')
+        else:
+            self.norm_name = None
+
+        # build activation layer
+        if self.with_activation:
+            # copy so the caller's dict (or the shared default) is not mutated
+            act_cfg_ = act_cfg.copy()
+            # These activation types take no 'inplace' argument, so passing
+            # one would raise a TypeError when the layer is built. GELU and
+            # Clamp (also registered as 'Clip') belong here as well.
+            if act_cfg_['type'] not in [
+                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU',
+                    'Clamp', 'Clip'
+            ]:
+                act_cfg_.setdefault('inplace', inplace)
+            self.activate = build_activation_layer(act_cfg_)
+
+        # Use msra init by default
+        self.init_weights()
+
+    @property
+    def norm(self):
+        """The norm layer registered under ``norm_name``, or None."""
+        if self.norm_name:
+            return getattr(self, self.norm_name)
+        else:
+            return None
+
+    def init_weights(self):
+        """Initialize the conv weights and the norm layer."""
+        # 1. It is mainly for customized conv layers with their own
+        #    initialization manners by calling their own ``init_weights()``,
+        #    and we do not want ConvModule to override the initialization.
+        # 2. For customized conv layers without their own initialization
+        #    manners (that is, they don't have their own ``init_weights()``)
+        #    and PyTorch's conv layers, they will be initialized by
+        #    this method with default ``kaiming_init``.
+        # Note: For PyTorch's conv layers, they will be overwritten by our
+        #    initialization implementation using default ``kaiming_init``.
+        if not hasattr(self.conv, 'init_weights'):
+            if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+                nonlinearity = 'leaky_relu'
+                a = self.act_cfg.get('negative_slope', 0.01)
+            else:
+                nonlinearity = 'relu'
+                a = 0
+            kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+        if self.with_norm:
+            constant_init(self.norm, 1, bias=0)
+
+    def forward(self, x, activate=True, norm=True):
+        """Apply the layers in ``self.order``.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            activate (bool): Set to False to skip the activation layer.
+            norm (bool): Set to False to skip the norm layer.
+
+        Returns:
+            torch.Tensor: Output of the last applied layer.
+        """
+        for layer in self.order:
+            if layer == 'conv':
+                if self.with_explicit_padding:
+                    x = self.padding_layer(x)
+                x = self.conv(x)
+            elif layer == 'norm' and norm and self.with_norm:
+                x = self.norm(x)
+            elif layer == 'act' and activate and self.with_activation:
+                x = self.activate(x)
+        return x
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .registry import CONV_LAYERS
+
+
+def conv_ws_2d(input,
+               weight,
+               bias=None,
+               stride=1,
+               padding=0,
+               dilation=1,
+               groups=1,
+               eps=1e-5):
+    """2D convolution with weight standardization.
+
+    Every filter ``weight[i]`` is shifted by its own mean and divided by
+    its own standard deviation plus ``eps``; the standardized weight is
+    then passed to ``F.conv2d`` together with the remaining arguments.
+
+    Args:
+        input: Input passed to ``F.conv2d``.
+        weight: Convolution weight; statistics are taken per index of
+            dimension 0.
+        bias, stride, padding, dilation, groups: Forwarded to ``F.conv2d``.
+        eps (float): Added to the standard deviation before dividing.
+
+    Returns:
+        The result of ``F.conv2d`` with the standardized weight.
+    """
+    num_filters = weight.size(0)
+    per_filter = weight.view(num_filters, -1)
+    stat_shape = (num_filters, 1, 1, 1)
+    filter_mean = per_filter.mean(dim=1, keepdim=True).view(stat_shape)
+    filter_std = per_filter.std(dim=1, keepdim=True).view(stat_shape)
+    standardized = (weight - filter_mean) / (filter_std + eps)
+    return F.conv2d(input, standardized, bias, stride, padding, dilation,
+                    groups)
+
+
+@CONV_LAYERS.register_module('ConvWS')
+class ConvWS2d(nn.Conv2d):
+    """``nn.Conv2d`` whose forward pass goes through ``conv_ws_2d``.
+
+    Args:
+        eps (float): Passed to ``conv_ws_2d`` on every forward call.
+            Default: 1e-5.
+
+    All other arguments are forwarded unchanged to ``nn.Conv2d``.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 eps=1e-5):
+        conv_kwargs = dict(
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        super().__init__(in_channels, out_channels, kernel_size,
+                         **conv_kwargs)
+        self.eps = eps
+
+    def forward(self, x):
+        """Convolve ``x`` with the standardized version of ``self.weight``."""
+        return conv_ws_2d(
+            x,
+            self.weight,
+            bias=self.bias,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            groups=self.groups,
+            eps=self.eps)
+
+
+@CONV_LAYERS.register_module(name='ConvAWS')
+class ConvAWS2d(nn.Conv2d):
+    """AWS (Adaptive Weight Standardization)
+
+    A variant of Weight Standardization
+    (https://arxiv.org/pdf/1903.10520.pdf), used in DetectoRS to avoid NaN
+    (https://arxiv.org/pdf/2006.02334.pdf).
+
+    The standardized weight is rescaled by the buffer ``weight_gamma`` and
+    shifted by the buffer ``weight_beta`` before the convolution.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the conv kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If set True, adds a learnable bias to the
+            output. Default: True
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super(ConvAWS2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        # One scale and one shift per output filter.
+        stat_shape = (self.out_channels, 1, 1, 1)
+        self.register_buffer('weight_gamma', torch.ones(stat_shape))
+        self.register_buffer('weight_beta', torch.zeros(stat_shape))
+
+    def _get_weight(self, weight):
+        """Standardize ``weight`` per filter, then apply gamma and beta."""
+        per_filter = weight.view(weight.size(0), -1)
+        filter_mean = per_filter.mean(dim=1).view(-1, 1, 1, 1)
+        filter_std = torch.sqrt(per_filter.var(dim=1) + 1e-5).view(
+            -1, 1, 1, 1)
+        standardized = (weight - filter_mean) / filter_std
+        return self.weight_gamma * standardized + self.weight_beta
+
+    def forward(self, x):
+        """Convolve ``x`` with the adaptively standardized weight."""
+        return F.conv2d(x, self._get_weight(self.weight), self.bias,
+                        self.stride, self.padding, self.dilation, self.groups)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Override default load function.
+
+        ``weight_gamma`` is filled with -1 before the parent loader runs.
+        If its mean is positive afterwards, the parent's missing keys are
+        reported unchanged and the method returns. Otherwise the
+        per-filter mean and std of the loaded weight are copied into
+        ``weight_beta`` and ``weight_gamma``, and those two keys are not
+        reported as missing.
+        """
+
+        # Sentinel used below to decide whether to recompute gamma/beta.
+        self.weight_gamma.data.fill_(-1)
+        own_missing = []
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, own_missing, unexpected_keys,
+                                      error_msgs)
+        if self.weight_gamma.data.mean() > 0:
+            missing_keys.extend(own_missing)
+            return
+
+        loaded = self.weight.data
+        per_filter = loaded.view(loaded.size(0), -1)
+        filter_mean = per_filter.mean(dim=1).view(-1, 1, 1, 1)
+        filter_std = torch.sqrt(per_filter.var(dim=1) + 1e-5).view(
+            -1, 1, 1, 1)
+        self.weight_beta.data.copy_(filter_mean)
+        self.weight_gamma.data.copy_(filter_std)
+
+        # gamma/beta were just recovered, so they are not really missing.
+        recovered = ('weight_gamma', 'weight_beta')
+        missing_keys.extend(
+            key for key in own_missing if not key.endswith(recovered))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .conv_module import ConvModule
+
+
+class DepthwiseSeparableConvModule(nn.Module):
+    """Depthwise separable convolution module.
+
+    See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+    Built from two ConvModules applied in sequence: a depthwise one
+    (``in_channels`` -> ``in_channels`` with ``groups=in_channels``) that
+    receives ``kernel_size``, ``stride``, ``padding`` and ``dilation``, and
+    a pointwise one (kernel size 1, ``in_channels`` -> ``out_channels``).
+    Each ConvModule gets norm/activation layers according to the configs
+    resolved below.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+            Same as that in ``nn._ConvNd``.
+        out_channels (int): Number of channels produced by the convolution.
+            Same as that in ``nn._ConvNd``.
+        kernel_size (int | tuple[int]): Size of the convolving kernel.
+            Same as that in ``nn._ConvNd``.
+        stride (int | tuple[int]): Stride of the convolution.
+            Same as that in ``nn._ConvNd``. Default: 1.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Same as that in ``nn._ConvNd``. Default: 0.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+            Same as that in ``nn._ConvNd``. Default: 1.
+        norm_cfg (dict): Fallback norm config for both ConvModules.
+            Default: None.
+        act_cfg (dict): Fallback activation config for both ConvModules.
+            Default: dict(type='ReLU').
+        dw_norm_cfg (dict): Norm config of the depthwise ConvModule;
+            'default' means use `norm_cfg`. Default: 'default'.
+        dw_act_cfg (dict): Activation config of the depthwise ConvModule;
+            'default' means use `act_cfg`. Default: 'default'.
+        pw_norm_cfg (dict): Norm config of the pointwise ConvModule;
+            'default' means use `norm_cfg`. Default: 'default'.
+        pw_act_cfg (dict): Activation config of the pointwise ConvModule;
+            'default' means use `act_cfg`. Default: 'default'.
+        kwargs (optional): Other arguments passed to both ConvModules.
+            ``groups`` must not be among them.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 dw_norm_cfg='default',
+                 dw_act_cfg='default',
+                 pw_norm_cfg='default',
+                 pw_act_cfg='default',
+                 **kwargs):
+        super().__init__()
+        assert 'groups' not in kwargs, 'groups should not be specified'
+
+        def _resolve(specific, shared):
+            # The sentinel string 'default' selects the shared config.
+            return specific if specific != 'default' else shared
+
+        # Depthwise stage: one filter group per input channel.
+        self.depthwise_conv = ConvModule(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=in_channels,
+            norm_cfg=_resolve(dw_norm_cfg, norm_cfg),
+            act_cfg=_resolve(dw_act_cfg, act_cfg),
+            **kwargs)
+
+        # Pointwise stage: kernel size 1, changes the channel count.
+        self.pointwise_conv = ConvModule(
+            in_channels,
+            out_channels,
+            1,
+            norm_cfg=_resolve(pw_norm_cfg, norm_cfg),
+            act_cfg=_resolve(pw_act_cfg, act_cfg),
+            **kwargs)
+
+    def forward(self, x):
+        """Apply the depthwise ConvModule, then the pointwise ConvModule."""
+        return self.pointwise_conv(self.depthwise_conv(x))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..465ed38339fe64dde8cdc959451b1236a3a55b95
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of
+    residual blocks).
+
+    We follow the implementation
+    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+
+    Args:
+        x (Tensor): Input whose dimension 0 indexes the samples.
+        drop_prob (float): Probability of zeroing a sample. Default: 0.
+        training (bool): When False the input is returned as is.
+            Default: False.
+
+    Returns:
+        Tensor: ``x`` itself when ``drop_prob == 0`` or ``training`` is
+        False; otherwise a tensor in which each sample is either zeroed or
+        scaled by ``1 / (1 - drop_prob)``.
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    # handle tensors with different dimensions, not just 4D tensors.
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+    if keep_prob <= 0:
+        # Every sample is dropped. Dividing by keep_prob here would compute
+        # inf * 0 = NaN, so multiply by an explicit zero mask instead.
+        return x * x.new_zeros(shape)
+    random_tensor = keep_prob + torch.rand(
+        shape, dtype=x.dtype, device=x.device)
+    # floor() turns the shifted uniform sample into a 0/1 keep mask.
+    output = x.div(keep_prob) * random_tensor.floor()
+    return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+    """Module wrapper around ``drop_path`` (Stochastic Depth per sample, when
+    applied in main path of residual blocks).
+
+    We follow the implementation
+    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+
+    Args:
+        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+    """
+
+    def __init__(self, drop_prob=0.1):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        """Call ``drop_path`` with the stored probability and training flag."""
+        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+    """``torch.nn.Dropout`` with its ``p`` argument exposed as ``drop_prob``,
+    so that the constructor is consistent with ``DropPath``.
+
+    Args:
+        drop_prob (float): Probability of the elements to be
+            zeroed. Default: 0.5.
+        inplace (bool):  Do the operation inplace or not. Default: False.
+    """
+
+    def __init__(self, drop_prob=0.5, inplace=False):
+        # nn.Dropout takes (p, inplace) in this positional order.
+        super(Dropout, self).__init__(drop_prob, inplace)
+
+
+def build_dropout(cfg, default_args=None):
+    """Build a layer from ``cfg`` using the ``DROPOUT_LAYERS`` registry.
+
+    ``cfg`` and ``default_args`` are handed to ``build_from_cfg`` as is.
+    """
+    dropout_layer = build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
+    return dropout_layer
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class GeneralizedAttention(nn.Module):
+ """GeneralizedAttention module.
+
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+ (https://arxiv.org/abs/1904.05873) for details.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ spatial_range (int): The spatial range. -1 indicates no spatial range
+ constraint. Default: -1.
+ num_heads (int): The head number of empirical_attention module.
+ Default: 9.
+ position_embedding_dim (int): The position embedding dimension.
+ Default: -1.
+ position_magnitude (int): A multiplier acting on coord difference.
+ Default: 1.
+ kv_stride (int): The feature stride acting on key/value feature map.
+ Default: 2.
+ q_stride (int): The feature stride acting on query feature map.
+ Default: 1.
+ attention_type (str): A binary indicator string for indicating which
+ items in generalized empirical_attention module are used.
+ Default: '1111'.
+
+ - '1000' indicates 'query and key content' (appr - appr) item,
+ - '0100' indicates 'query content and relative position'
+ (appr - position) item,
+ - '0010' indicates 'key content only' (bias - appr) item,
+ - '0001' indicates 'relative position only' (bias - position) item.
+ """
+
+ _abbr_ = 'gen_attention_block'
+
+ def __init__(self,
+              in_channels,
+              spatial_range=-1,
+              num_heads=9,
+              position_embedding_dim=-1,
+              position_magnitude=1,
+              kv_stride=2,
+              q_stride=1,
+              attention_type='1111'):
+     """Create the projection layers, bias terms and optional locality mask.
+
+     Which layers exist depends on the four flags parsed from
+     ``attention_type``. When ``spatial_range >= 0`` a boolean mask
+     ``local_constraint_map`` is built; entries left True are the ones
+     ``forward`` fills with ``-inf`` before the softmax.
+
+     Raises:
+         ValueError: If ``spatial_range >= 0`` and ``in_channels`` is
+             neither 256 nor 512, the only values for which a mask size
+             is defined.
+     """
+
+     super(GeneralizedAttention, self).__init__()
+
+     # hard range means local range for non-local operation
+     self.position_embedding_dim = (
+         position_embedding_dim
+         if position_embedding_dim > 0 else in_channels)
+
+     self.position_magnitude = position_magnitude
+     self.num_heads = num_heads
+     self.in_channels = in_channels
+     self.spatial_range = spatial_range
+     self.kv_stride = kv_stride
+     self.q_stride = q_stride
+     self.attention_type = [bool(int(_)) for _ in attention_type]
+     self.qk_embed_dim = in_channels // num_heads
+     out_c = self.qk_embed_dim * num_heads
+
+     if self.attention_type[0] or self.attention_type[1]:
+         self.query_conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_c,
+             kernel_size=1,
+             bias=False)
+         self.query_conv.kaiming_init = True
+
+     if self.attention_type[0] or self.attention_type[2]:
+         self.key_conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_c,
+             kernel_size=1,
+             bias=False)
+         self.key_conv.kaiming_init = True
+
+     self.v_dim = in_channels // num_heads
+     self.value_conv = nn.Conv2d(
+         in_channels=in_channels,
+         out_channels=self.v_dim * num_heads,
+         kernel_size=1,
+         bias=False)
+     self.value_conv.kaiming_init = True
+
+     if self.attention_type[1] or self.attention_type[3]:
+         self.appr_geom_fc_x = nn.Linear(
+             self.position_embedding_dim // 2, out_c, bias=False)
+         self.appr_geom_fc_x.kaiming_init = True
+
+         self.appr_geom_fc_y = nn.Linear(
+             self.position_embedding_dim // 2, out_c, bias=False)
+         self.appr_geom_fc_y.kaiming_init = True
+
+     if self.attention_type[2]:
+         # uniform values in (-stdv, stdv]
+         stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+         appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+         self.appr_bias = nn.Parameter(appr_bias_value)
+
+     if self.attention_type[3]:
+         stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+         geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+         self.geom_bias = nn.Parameter(geom_bias_value)
+
+     self.proj_conv = nn.Conv2d(
+         in_channels=self.v_dim * num_heads,
+         out_channels=in_channels,
+         kernel_size=1,
+         bias=True)
+     self.proj_conv.kaiming_init = True
+     # gamma starts at zero, so forward() initially returns its input.
+     self.gamma = nn.Parameter(torch.zeros(1))
+
+     if self.spatial_range >= 0:
+         # only works when non local is after 3*3 conv
+         if in_channels == 256:
+             max_len = 84
+         elif in_channels == 512:
+             max_len = 42
+         else:
+             # Previously max_len was left unbound here, which surfaced
+             # as an UnboundLocalError a few lines below.
+             raise ValueError(
+                 'spatial_range >= 0 is only supported when in_channels '
+                 f'is 256 or 512, but got {in_channels}')
+
+         max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
+         # ``np.int`` was removed in NumPy 1.24. The map only holds 0/1
+         # flags, so it is built directly as bool, which is also the mask
+         # dtype ``masked_fill_`` expects in forward().
+         local_constraint_map = np.ones(
+             (max_len, max_len, max_len_kv, max_len_kv), dtype=bool)
+         for iy in range(max_len):
+             # The row window does not depend on ix; compute it once.
+             y_lo = max((iy - self.spatial_range) // self.kv_stride, 0)
+             y_hi = min(
+                 (iy + self.spatial_range + 1) // self.kv_stride + 1,
+                 max_len)
+             for ix in range(max_len):
+                 x_lo = max((ix - self.spatial_range) // self.kv_stride, 0)
+                 x_hi = min(
+                     (ix + self.spatial_range + 1) // self.kv_stride + 1,
+                     max_len)
+                 # False marks key positions the query may attend to.
+                 local_constraint_map[iy, ix, y_lo:y_hi, x_lo:x_hi] = False
+
+         self.local_constraint_map = nn.Parameter(
+             torch.from_numpy(local_constraint_map), requires_grad=False)
+
+     if self.q_stride > 1:
+         self.q_downsample = nn.AvgPool2d(
+             kernel_size=1, stride=self.q_stride)
+     else:
+         self.q_downsample = None
+
+     if self.kv_stride > 1:
+         self.kv_downsample = nn.AvgPool2d(
+             kernel_size=1, stride=self.kv_stride)
+     else:
+         self.kv_downsample = None
+
+     self.init_weights()
+
+ def get_position_embedding(self,
+                            h,
+                            w,
+                            h_kv,
+                            w_kv,
+                            q_stride,
+                            kv_stride,
+                            device,
+                            dtype,
+                            feat_dim,
+                            wave_length=1000):
+     """Build sine/cosine embeddings of query-to-key coordinate offsets.
+
+     Args:
+         h, w: Number of query rows / columns.
+         h_kv, w_kv: Number of key rows / columns.
+         q_stride, kv_stride: Multipliers applied to the query / key
+             indices before the offsets are taken.
+         device, dtype: Device and dtype of the returned tensors.
+         feat_dim: ``feat_dim / 4`` frequencies are used, so each
+             embedding has ``feat_dim / 2`` channels.
+         wave_length: Base of the frequency progression. Default: 1000.
+
+     Returns:
+         tuple: ``(embedding_x, embedding_y)`` with shapes
+         ``(w, w_kv, feat_dim / 2)`` and ``(h, h_kv, feat_dim / 2)``.
+     """
+
+     def _scaled_indices(length, stride):
+         # linspace produces float32; cast so that fp16 inputs do not run
+         # into a dtype mismatch.
+         idxs = torch.linspace(0, length - 1, length).to(
+             device=device, dtype=dtype)
+         return idxs.view((length, 1)) * stride
+
+     # (h, h_kv, 1)
+     h_diff = (_scaled_indices(h, q_stride).unsqueeze(1) -
+               _scaled_indices(h_kv, kv_stride).unsqueeze(0))
+     h_diff *= self.position_magnitude
+
+     # (w, w_kv, 1)
+     w_diff = (_scaled_indices(w, q_stride).unsqueeze(1) -
+               _scaled_indices(w_kv, kv_stride).unsqueeze(0))
+     w_diff *= self.position_magnitude
+
+     feat_range = torch.arange(0, feat_dim / 4).to(
+         device=device, dtype=dtype)
+
+     # wave_length ** (4 * i / feat_dim), broadcastable as (1, 1, feat_dim / 4)
+     base = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
+     dim_mat = (base**((4. / feat_dim) * feat_range)).view((1, 1, -1))
+
+     scaled_w = w_diff / dim_mat
+     scaled_h = h_diff / dim_mat
+     embedding_x = torch.cat((scaled_w.sin(), scaled_w.cos()), dim=2)
+     embedding_y = torch.cat((scaled_h.sin(), scaled_h.cos()), dim=2)
+
+     return embedding_x, embedding_y
+
+ def forward(self, x_input):
+ """Apply the attention terms enabled in ``self.attention_type``.
+
+ Args:
+ x_input (Tensor): 4-D input; its shape is unpacked as (n, c, h, w).
+
+ Returns:
+ Tensor: ``self.gamma * out + x_input``, where ``out`` is the
+ projected attention output (interpolated back to the spatial size
+ of ``x_input`` when the query was downsampled).
+ """
+ num_heads = self.num_heads
+
+ # use empirical_attention
+ # Queries and keys/values may each come from a strided copy of the input.
+ if self.q_downsample is not None:
+ x_q = self.q_downsample(x_input)
+ else:
+ x_q = x_input
+ n, _, h, w = x_q.shape
+
+ if self.kv_downsample is not None:
+ x_kv = self.kv_downsample(x_input)
+ else:
+ x_kv = x_input
+ _, _, h_kv, w_kv = x_kv.shape
+
+ # proj_query: (n, num_heads, h*w, qk_embed_dim) after the permute
+ if self.attention_type[0] or self.attention_type[1]:
+ proj_query = self.query_conv(x_q).view(
+ (n, num_heads, self.qk_embed_dim, h * w))
+ proj_query = proj_query.permute(0, 1, 3, 2)
+
+ # proj_key: (n, num_heads, qk_embed_dim, h_kv*w_kv)
+ if self.attention_type[0] or self.attention_type[2]:
+ proj_key = self.key_conv(x_kv).view(
+ (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+ if self.attention_type[1] or self.attention_type[3]:
+ position_embed_x, position_embed_y = self.get_position_embedding(
+ h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+ x_input.device, x_input.dtype, self.position_embedding_dim)
+ # (n, num_heads, w, w_kv, dim)
+ position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+ view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ # (n, num_heads, h, h_kv, dim)
+ position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+ view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+ permute(0, 3, 1, 2, 4).\
+ repeat(n, 1, 1, 1, 1)
+
+ position_feat_x /= math.sqrt(2)
+ position_feat_y /= math.sqrt(2)
+
+ # accelerate for saliency only
+ # Only the 'bias - appr' term is enabled: the energy does not depend on
+ # the query position, so a single query row (h = w = 1) is computed.
+ if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy = torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, h_kv * w_kv)
+
+ h = 1
+ w = 1
+ else:
+ # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+ # Without the 'appr - appr' term, start from zeros so that the
+ # remaining terms can be accumulated with ``+=``.
+ if not self.attention_type[0]:
+ energy = torch.zeros(
+ n,
+ num_heads,
+ h,
+ w,
+ h_kv,
+ w_kv,
+ dtype=x_input.dtype,
+ device=x_input.device)
+
+ # attention_type[0]: appr - appr
+ # attention_type[1]: appr - position
+ # attention_type[2]: bias - appr
+ # attention_type[3]: bias - position
+ if self.attention_type[0] or self.attention_type[2]:
+ if self.attention_type[0] and self.attention_type[2]:
+ # Both content terms share proj_key, so they are fused into
+ # one matmul.
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+ energy = torch.matmul(proj_query + appr_bias, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[0]:
+ energy = torch.matmul(proj_query, proj_key).\
+ view(n, num_heads, h, w, h_kv, w_kv)
+
+ elif self.attention_type[2]:
+ appr_bias = self.appr_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim).\
+ repeat(n, 1, 1, 1)
+
+ energy += torch.matmul(appr_bias, proj_key).\
+ view(n, num_heads, 1, 1, h_kv, w_kv)
+
+ if self.attention_type[1] or self.attention_type[3]:
+ if self.attention_type[1] and self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, 1, self.qk_embed_dim)
+
+ proj_query_reshape = (proj_query + geom_bias).\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+
+ energy_x = torch.matmul(
+ proj_query_reshape.permute(0, 1, 3, 2, 4),
+ position_feat_x.permute(0, 1, 2, 4, 3))
+ energy_x = energy_x.\
+ permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(
+ proj_query_reshape,
+ position_feat_y.permute(0, 1, 2, 4, 3))
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[1]:
+ proj_query_reshape = proj_query.\
+ view(n, num_heads, h, w, self.qk_embed_dim)
+ proj_query_reshape = proj_query_reshape.\
+ permute(0, 1, 3, 2, 4)
+ position_feat_x_reshape = position_feat_x.\
+ permute(0, 1, 2, 4, 3)
+ position_feat_y_reshape = position_feat_y.\
+ permute(0, 1, 2, 4, 3)
+
+ energy_x = torch.matmul(proj_query_reshape,
+ position_feat_x_reshape)
+ energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+ energy_y = torch.matmul(proj_query_reshape,
+ position_feat_y_reshape)
+ energy_y = energy_y.unsqueeze(5)
+
+ energy += energy_x + energy_y
+
+ elif self.attention_type[3]:
+ geom_bias = self.geom_bias.\
+ view(1, num_heads, self.qk_embed_dim, 1).\
+ repeat(n, 1, 1, 1)
+
+ position_feat_x_reshape = position_feat_x.\
+ view(n, num_heads, w*w_kv, self.qk_embed_dim)
+
+ position_feat_y_reshape = position_feat_y.\
+ view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+ energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+ energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+ energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+ energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+ energy += energy_x + energy_y
+
+ energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+ # Entries selected by the constraint map are set to -inf before the
+ # softmax over the key axis.
+ if self.spatial_range >= 0:
+ cur_local_constraint_map = \
+ self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+ contiguous().\
+ view(1, 1, h*w, h_kv*w_kv)
+
+ energy = energy.masked_fill_(cur_local_constraint_map,
+ float('-inf'))
+
+ attention = F.softmax(energy, 3)
+
+ # proj_value_reshape: (n, num_heads, h_kv*w_kv, v_dim)
+ proj_value = self.value_conv(x_kv)
+ proj_value_reshape = proj_value.\
+ view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+ permute(0, 1, 3, 2)
+
+ out = torch.matmul(attention, proj_value_reshape).\
+ permute(0, 1, 3, 2).\
+ contiguous().\
+ view(n, self.v_dim * self.num_heads, h, w)
+
+ out = self.proj_conv(out)
+
+ # output is downsampled, upsample back to input size
+ if self.q_downsample is not None:
+ out = F.interpolate(
+ out,
+ size=x_input.shape[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ # Residual connection scaled by the learnable ``gamma``.
+ out = self.gamma * out + x_input
+ return out
+
+ def init_weights(self):
+     """Kaiming-initialize every submodule whose ``kaiming_init`` flag is set.
+
+     Submodules without that attribute, or with a falsy value, are left
+     as they are.
+     """
+     init_kwargs = dict(
+         mode='fan_in',
+         nonlinearity='leaky_relu',
+         bias=0,
+         distribution='uniform',
+         a=1)
+     for module in self.modules():
+         if not getattr(module, 'kaiming_init', False):
+             continue
+         kaiming_init(module, **init_kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSigmoid(nn.Module):
+    """Hard Sigmoid Module. Apply the hard sigmoid function:
+    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+
+    Args:
+        bias (float): Value added to the input. Default: 1.0.
+        divisor (float): Value the shifted input is divided by; must not
+            be 0. Default: 2.0.
+        min_value (float): Lower bound value. Default: 0.0.
+        max_value (float): Upper bound value. Default: 1.0.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+        super().__init__()
+        self.bias = bias
+        self.divisor = divisor
+        assert self.divisor != 0
+        self.min_value, self.max_value = min_value, max_value
+
+    def forward(self, x):
+        """Shift and scale ``x``, then clamp to [min_value, max_value]."""
+        scaled = (x + self.bias) / self.divisor
+        # In-place clamp acts on the freshly created tensor, not on ``x``.
+        return scaled.clamp_(self.min_value, self.max_value)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSwish(nn.Module):
+    """Hard Swish Module.
+
+    This module applies the hard swish function:
+
+    .. math::
+        Hswish(x) = x * ReLU6(x + 3) / 6
+
+    Args:
+        inplace (bool): Forwarded to the internal ``nn.ReLU6``.
+            Default: False.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self, inplace=False):
+        super().__init__()
+        self.act = nn.ReLU6(inplace=inplace)
+
+    def forward(self, x):
+        """Return ``x * ReLU6(x + 3) / 6``."""
+        gate = self.act(x + 3)
+        return x * gate / 6
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+
+import torch
+import torch.nn as nn
+
+from ..utils import constant_init, normal_init
+from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS
+
+
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+    """Basic Non-local module.
+
+    This module is proposed in
+    "Non-local Neural Networks"
+    Paper reference: https://arxiv.org/abs/1711.07971
+    Code reference: https://github.com/AlexHex7/Non-local_pytorch
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        reduction (int): Channel reduction ratio. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+            Default: True.
+        conv_cfg (None | dict): The config dict for convolution layers.
+            If not specified, it will use `nn.Conv2d` for convolution layers.
+            Default: None.
+        norm_cfg (None | dict): The config dict for normalization layers.
+            Default: None. (This parameter is only applicable to conv_out.)
+        mode (str): Options are `gaussian`, `concatenation`,
+            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+        kwargs: Forwarded to ``init_weights`` (``std``, ``zeros_init``).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 reduction=2,
+                 use_scale=True,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super().__init__()
+        self.in_channels = in_channels
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.inter_channels = max(in_channels // reduction, 1)
+        self.mode = mode
+
+        supported_modes = ('gaussian', 'embedded_gaussian', 'dot_product',
+                           'concatenation')
+        if mode not in supported_modes:
+            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+                             f"'embedded_gaussian' or 'dot_product', but got "
+                             f'{mode} instead.')
+
+        def _pointwise(in_ch, out_ch, **extra):
+            # Kernel-size-1 ConvModule without activation.
+            return ConvModule(
+                in_ch,
+                out_ch,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                act_cfg=None,
+                **extra)
+
+        # g, theta, phi are defaulted as `nn.ConvNd`.
+        # Here we use ConvModule for potential usage.
+        self.g = _pointwise(self.in_channels, self.inter_channels)
+        self.conv_out = _pointwise(
+            self.inter_channels, self.in_channels, norm_cfg=norm_cfg)
+
+        # The plain gaussian mode has no learned theta / phi embeddings.
+        if self.mode != 'gaussian':
+            self.theta = _pointwise(self.in_channels, self.inter_channels)
+            self.phi = _pointwise(self.in_channels, self.inter_channels)
+
+        if self.mode == 'concatenation':
+            self.concat_project = ConvModule(
+                self.inter_channels * 2,
+                1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+                act_cfg=dict(type='ReLU'))
+
+        self.init_weights(**kwargs)
+
+    def init_weights(self, std=0.01, zeros_init=True):
+        """Initialize the projection convs and the output layer.
+
+        Args:
+            std (float): Std used for every ``normal_init`` call.
+            zeros_init (bool): When True the output layer (the conv of
+                ``conv_out``, or its norm layer if ``norm_cfg`` is set) is
+                set to the constant 0; otherwise it is normal-initialized.
+        """
+        if self.mode != 'gaussian':
+            embeddings = [self.g, self.theta, self.phi]
+        else:
+            embeddings = [self.g]
+        for embedding in embeddings:
+            normal_init(embedding.conv, std=std)
+
+        if self.conv_out.norm_cfg is None:
+            out_layer = self.conv_out.conv
+        else:
+            out_layer = self.conv_out.norm
+        if zeros_init:
+            constant_init(out_layer, 0)
+        else:
+            normal_init(out_layer, std=std)
+
+    # For all pairwise functions the returned weight has shape
+    #   NonLocal1d: [N, H, H]
+    #   NonLocal2d: [N, HxW, HxW]
+    #   NonLocal3d: [N, TxHxW, TxHxW]
+
+    def gaussian(self, theta_x, phi_x):
+        """Softmax over the last dim of ``theta_x @ phi_x``."""
+        similarity = torch.matmul(theta_x, phi_x)
+        return similarity.softmax(dim=-1)
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        """Like ``gaussian``, optionally scaled by ``1/sqrt(channels)``."""
+        similarity = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            similarity /= theta_x.shape[-1]**0.5
+        return similarity.softmax(dim=-1)
+
+    def dot_product(self, theta_x, phi_x):
+        """``theta_x @ phi_x`` divided by the size of its last dim."""
+        similarity = torch.matmul(theta_x, phi_x)
+        similarity /= similarity.shape[-1]
+        return similarity
+
+    def concatenation(self, theta_x, phi_x):
+        """Project the concatenated pair features to one weight per pair."""
+        num_queries = theta_x.size(2)
+        num_keys = phi_x.size(3)
+        # Tile both inputs to a common (num_queries, num_keys) grid.
+        theta_grid = theta_x.repeat(1, 1, 1, num_keys)
+        phi_grid = phi_x.repeat(1, 1, num_queries, 1)
+
+        pair_features = torch.cat([theta_grid, phi_grid], dim=1)
+        pairwise_weight = self.concat_project(pair_features)
+        n, _, h, w = pairwise_weight.size()
+        pairwise_weight = pairwise_weight.view(n, h, w)
+        pairwise_weight /= pairwise_weight.shape[-1]
+
+        return pairwise_weight
+
+    def forward(self, x):
+        """Return ``x + conv_out(y)`` where ``y`` aggregates ``g(x)`` with the
+        pairwise weights of the configured mode.
+
+        Shape comments assume `reduction = 1`, i.e. `inter_channels = C`:
+            NonLocal1d x: [N, C, H]
+            NonLocal2d x: [N, C, H, W]
+            NonLocal3d x: [N, C, T, H, W]
+        """
+        n = x.size(0)
+
+        # g_x: [N, H, C] / [N, HxW, C] / [N, TxHxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1).permute(0, 2, 1)
+
+        # theta_x: [N, H, C] / [N, HxW, C] / [N, TxHxW, C]
+        # phi_x:   [N, C, H] / [N, C, HxW] / [N, C, TxHxW]
+        if self.mode == 'gaussian':
+            # No learned embeddings: the raw input plays both roles.
+            theta_x = x.view(n, self.in_channels, -1).permute(0, 2, 1)
+            phi_source = self.phi(x) if self.sub_sample else x
+            phi_x = phi_source.view(n, self.in_channels, -1)
+        elif self.mode == 'concatenation':
+            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+        else:
+            theta_x = self.theta(x).view(n, self.inter_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        # The mode name doubles as the name of the pairwise method.
+        pairwise_weight = getattr(self, self.mode)(theta_x, phi_x)
+
+        # y: [N, H, C] / [N, HxW, C] / [N, TxHxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # y: [N, C, H] / [N, C, H, W] / [N, C, T, H, W]
+        spatial_shape = x.size()[2:]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *spatial_shape)
+
+        return x + self.conv_out(y)
+
+
+class NonLocal1d(_NonLocalNd):
+    """1D Non-local module.
+
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): When True, an ``nn.MaxPool1d(kernel_size=2)`` is
+            appended to ``g`` and to ``phi`` (in `gaussian` mode the pooling
+            layer itself becomes ``phi``). Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv1d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv1d'),
+                 **kwargs):
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+
+        self.sub_sample = sub_sample
+        if not sub_sample:
+            return
+
+        pool = nn.MaxPool1d(kernel_size=2)
+        self.g = nn.Sequential(self.g, pool)
+        if self.mode == 'gaussian':
+            self.phi = pool
+        else:
+            self.phi = nn.Sequential(self.phi, pool)
+
+
+@PLUGIN_LAYERS.register_module()
+class NonLocal2d(_NonLocalNd):
+    """2D Non-local module.
+
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): When True, an ``nn.MaxPool2d(kernel_size=(2, 2))``
+            is appended to ``g`` and to ``phi`` (in `gaussian` mode the
+            pooling layer itself becomes ``phi``). Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv2d').
+    """
+
+    _abbr_ = 'nonlocal_block'
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv2d'),
+                 **kwargs):
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+
+        self.sub_sample = sub_sample
+        if not sub_sample:
+            return
+
+        pool = nn.MaxPool2d(kernel_size=(2, 2))
+        self.g = nn.Sequential(self.g, pool)
+        if self.mode == 'gaussian':
+            self.phi = pool
+        else:
+            self.phi = nn.Sequential(self.phi, pool)
+
+
+class NonLocal3d(_NonLocalNd):
+    """3D Non-local module.
+
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): When True, an
+            ``nn.MaxPool3d(kernel_size=(1, 2, 2))`` is appended to ``g`` and
+            to ``phi`` (in `gaussian` mode the pooling layer itself becomes
+            ``phi``). Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv3d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv3d'),
+                 **kwargs):
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
+        self.sub_sample = sub_sample
+        if not sub_sample:
+            return
+
+        # Kernel size 1 along the first pooled dim leaves it unpooled.
+        pool = nn.MaxPool3d(kernel_size=(1, 2, 2))
+        self.g = nn.Sequential(self.g, pool)
+        if self.mode == 'gaussian':
+            self.phi = pool
+        else:
+            self.phi = nn.Sequential(self.phi, pool)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f4e49b24080485fc1d85b3e8ff810dc1383c95
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv.utils import is_tuple_of
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+# Register the torch normalization classes under short type names.
+# 'BN' and 'BN2d' both refer to nn.BatchNorm2d, and 'IN' and 'IN2d' both
+# refer to nn.InstanceNorm2d.
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+def infer_abbr(class_type):
+    """Infer abbreviation from the class name.
+
+    When we build a norm layer with `build_norm_layer()`, we want to preserve
+    the norm type in variable names, e.g, self.bn1, self.gn. This method will
+    infer the abbreviation to map class types to abbreviations.
+
+    Rule 1: If the class has the attribute "_abbr_", return that attribute.
+    Rule 2: If the class is a subclass of _InstanceNorm, _BatchNorm,
+    GroupNorm or LayerNorm, the abbreviation of this layer will be "in",
+    "bn", "gn" and "ln" respectively.
+    Rule 3: If the lower-cased class name contains "batch", "group", "layer"
+    or "instance", the abbreviation of this layer will be "bn", "gn", "ln"
+    and "in" respectively.
+    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
+
+    Args:
+        class_type (type): The norm layer type.
+
+    Returns:
+        str: The inferred abbreviation.
+
+    Raises:
+        TypeError: If ``class_type`` is not a class.
+    """
+    if not inspect.isclass(class_type):
+        raise TypeError(
+            f'class_type must be a type, but got {type(class_type)}')
+    if hasattr(class_type, '_abbr_'):
+        return class_type._abbr_
+
+    # Instance norm has to be tested before batch norm: IN is a subclass
+    # of BN, so the batch-norm test would match it as well.
+    if issubclass(class_type, _InstanceNorm):
+        return 'in'
+    if issubclass(class_type, _BatchNorm):
+        return 'bn'
+    if issubclass(class_type, nn.GroupNorm):
+        return 'gn'
+    if issubclass(class_type, nn.LayerNorm):
+        return 'ln'
+
+    # No known base class: fall back to keywords in the class name.
+    class_name = class_type.__name__.lower()
+    if 'batch' in class_name:
+        return 'bn'
+    if 'group' in class_name:
+        return 'gn'
+    if 'layer' in class_name:
+        return 'ln'
+    if 'instance' in class_name:
+        return 'in'
+    return 'norm_layer'
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+    """Create a normalization layer from a config dict.
+
+    Args:
+        cfg (dict): The norm layer config, which should contain:
+
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a norm layer.
+            - requires_grad (bool, optional): Value assigned to
+              ``requires_grad`` of every parameter of the layer.
+              Defaults to True.
+        num_features (int): Number of input channels.
+        postfix (int | str): The postfix to be appended into norm abbreviation
+            to create named layer.
+
+    Returns:
+        (str, nn.Module): The first element is the layer name consisting of
+            abbreviation and postfix, e.g., bn1, gn. The second element is the
+            created norm layer.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+
+    # Work on a shallow copy so the caller's dict keeps all of its keys.
+    layer_args = cfg.copy()
+    layer_type = layer_args.pop('type')
+    if layer_type not in NORM_LAYERS:
+        raise KeyError(f'Unrecognized norm type {layer_type}')
+
+    norm_cls = NORM_LAYERS.get(layer_type)
+    abbr = infer_abbr(norm_cls)
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+
+    requires_grad = layer_args.pop('requires_grad', True)
+    layer_args.setdefault('eps', 1e-5)
+
+    if layer_type == 'GN':
+        # GroupNorm takes the channel count as ``num_channels``.
+        assert 'num_groups' in layer_args
+        layer = norm_cls(num_channels=num_features, **layer_args)
+    else:
+        layer = norm_cls(num_features, **layer_args)
+        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+            layer._specify_ddp_gpu_num(1)
+
+    for param in layer.parameters():
+        param.requires_grad = requires_grad
+
+    return name, layer
+
+
+def is_norm(layer, exclude=None):
+    """Tell whether ``layer`` is a normalization layer.
+
+    Args:
+        layer (nn.Module): The layer to be checked.
+        exclude (type | tuple[type]): Types to be excluded.
+
+    Returns:
+        bool: Whether the layer is a norm layer.
+
+    Raises:
+        TypeError: If ``exclude`` is neither None, a type, nor a tuple of
+            types.
+    """
+    if exclude is not None:
+        # Normalize a single type into a one-element tuple.
+        if not isinstance(exclude, tuple):
+            exclude = (exclude, )
+        if not is_tuple_of(exclude, type):
+            raise TypeError(
+                f'"exclude" must be either None or type or a tuple of types, '
+                f'but got {type(exclude)}: {exclude}')
+        if exclude and isinstance(layer, exclude):
+            return False
+
+    return isinstance(
+        layer, (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import PADDING_LAYERS
+
+# Register the 2D torch padding classes under short type names.
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+def build_padding_layer(cfg, *args, **kwargs):
+    """Create a padding layer from a config dict.
+
+    Args:
+        cfg (dict): The padding layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a padding layer.
+        args: Positional arguments forwarded to the padding layer class.
+        kwargs: Keyword arguments forwarded to the padding layer class.
+
+    Returns:
+        nn.Module: Created padding layer.
+
+    Raises:
+        TypeError: If ``cfg`` is not a dict.
+        KeyError: If ``cfg`` has no "type" key or the type is not registered.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+
+    # Shallow copy so that popping "type" does not touch the caller's dict.
+    layer_args = cfg.copy()
+    padding_type = layer_args.pop('type')
+    if padding_type not in PADDING_LAYERS:
+        raise KeyError(f'Unrecognized padding type {padding_type}.')
+
+    padding_cls = PADDING_LAYERS.get(padding_type)
+    return padding_cls(*args, **kwargs, **layer_args)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py
@@ -0,0 +1,88 @@
+import inspect
+import platform
+
+from .registry import PLUGIN_LAYERS
+
+# Choose which module is bound to the name ``re``: the ``regex`` package on
+# Windows, the standard-library ``re`` everywhere else.
+# NOTE(review): ``regex`` is not a standard-library module -- confirm it is
+# installed wherever this runs on Windows.
+if platform.system() == 'Windows':
+ import regex as re
+else:
+ import re
+
+
+def infer_abbr(class_type):
+    """Infer abbreviation from the class name.
+
+    This method will infer the abbreviation to map class types to
+    abbreviations.
+
+    Rule 1: If the class has the attribute "_abbr_", return that attribute.
+    Rule 2: Otherwise, the abbreviation falls back to snake case of class
+    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+    Args:
+        class_type (type): The plugin layer type.
+
+    Returns:
+        str: The inferred abbreviation.
+
+    Raises:
+        TypeError: If ``class_type`` is not a class.
+    """
+
+    def camel2snake(word):
+        """Convert camel case word into snake case.
+
+        Adapted from the ``underscore`` function of the ``inflection``
+        library.
+
+        Example::
+
+            >>> camel2snake("FancyBlock")
+            'fancy_block'
+        """
+        # Split an upper-case run from a following capitalized word
+        # ("HTTPServer" -> "HTTP_Server").
+        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+        # Split a lower-case letter or digit from a following capital
+        # ("FancyBlock" -> "Fancy_Block").
+        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+        word = word.replace('-', '_')
+        return word.lower()
+
+    if not inspect.isclass(class_type):
+        raise TypeError(
+            f'class_type must be a type, but got {type(class_type)}')
+    if hasattr(class_type, '_abbr_'):
+        return class_type._abbr_
+    return camel2snake(class_type.__name__)
+
+
+def build_plugin_layer(cfg, postfix='', **kwargs):
+    """Create a plugin layer from a config dict.
+
+    Args:
+        cfg (dict): cfg should contain:
+            type (str): identify plugin layer type.
+            layer args: args needed to instantiate a plugin layer.
+        postfix (int, str): appended into the plugin abbreviation to
+            create named layer. Default: ''.
+        kwargs: Extra keyword arguments forwarded to the plugin layer class.
+
+    Returns:
+        tuple[str, nn.Module]:
+            name (str): abbreviation + postfix
+            layer (nn.Module): created plugin layer
+
+    Raises:
+        TypeError: If ``cfg`` is not a dict.
+        KeyError: If ``cfg`` has no "type" key or the type is not registered.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+
+    # Shallow copy so that popping "type" does not touch the caller's dict.
+    layer_args = cfg.copy()
+    layer_type = layer_args.pop('type')
+    if layer_type not in PLUGIN_LAYERS:
+        raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+    plugin_cls = PLUGIN_LAYERS.get(layer_type)
+    abbr = infer_abbr(plugin_cls)
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+
+    return name, plugin_cls(**kwargs, **layer_args)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f374cca4961c06babf328bb7407723a14026c47
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.mmpkg.mmcv.utils import Registry
+
+# One module-level registry per kind of basic layer.
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+# Registries for dropout and transformer-related components.
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class Scale(nn.Module):
+    """Multiply the input by a single learnable factor.
+
+    The factor is stored as a float parameter built from a Python scalar
+    and is multiplied with the input in ``forward``.
+
+    Args:
+        scale (float): Initial value of scale factor. Default: 1.0
+    """
+
+    def __init__(self, scale=1.0):
+        super().__init__()
+        initial = torch.tensor(scale, dtype=torch.float)
+        self.scale = nn.Parameter(initial)
+
+    def forward(self, x):
+        """Return ``x`` multiplied by the learnable factor."""
+        factor = self.scale
+        return x * factor
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class Swish(nn.Module):
+    """Swish activation.
+
+    Applies the swish function element-wise:
+
+    .. math::
+        Swish(x) = x * Sigmoid(x)
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        """Return ``x`` gated by its own sigmoid."""
+        gate = torch.sigmoid(x)
+        return x * gate
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16707142b645144b676059ffa992fc4306ef778
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py
@@ -0,0 +1,595 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv import ConfigDict, deprecated_api_warning
+from annotator.mmpkg.mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from annotator.mmpkg.mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from annotator.mmpkg.mmcv.utils import build_from_cfg
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+ TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+ # Re-export the class from its new location so that importing it from this
+ # module keeps working.
+ from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
+ # This warning is issued every time the import above succeeds.
+ warnings.warn(
+ ImportWarning(
+ '``MultiScaleDeformableAttention`` has been moved to '
+ '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
+ '``from annotator.mmpkg.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
+ 'to ``from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
+ ))
+
+except ImportError:
+ # The import failure is only reported as a warning; the rest of this module
+ # is still loaded.
+ warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
+ '``mmcv.ops.multi_scale_deform_attn``, '
+ 'You should install ``mmcv-full`` if you need this module. ')
+
+
+def build_positional_encoding(cfg, default_args=None):
+ """Builder for Position Encoding.
+
+ Args:
+ cfg: Config passed unchanged to ``build_from_cfg``.
+ default_args: Passed unchanged to ``build_from_cfg``. Default: None.
+
+ Returns:
+ The result of ``build_from_cfg`` for the ``POSITIONAL_ENCODING``
+ registry.
+ """
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+
+
+def build_attention(cfg, default_args=None):
+ """Builder for attention.
+
+ Args:
+ cfg: Config passed unchanged to ``build_from_cfg``.
+ default_args: Passed unchanged to ``build_from_cfg``. Default: None.
+
+ Returns:
+ The result of ``build_from_cfg`` for the ``ATTENTION`` registry.
+ """
+ return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+def build_feedforward_network(cfg, default_args=None):
+ """Builder for feed-forward network (FFN).
+
+ Args:
+ cfg: Config passed unchanged to ``build_from_cfg``.
+ default_args: Passed unchanged to ``build_from_cfg``. Default: None.
+
+ Returns:
+ The result of ``build_from_cfg`` for the ``FEEDFORWARD_NETWORK``
+ registry.
+ """
+ return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
+def build_transformer_layer(cfg, default_args=None):
+ """Builder for transformer layer.
+
+ Args:
+ cfg: Config passed unchanged to ``build_from_cfg``.
+ default_args: Passed unchanged to ``build_from_cfg``. Default: None.
+
+ Returns:
+ The result of ``build_from_cfg`` for the ``TRANSFORMER_LAYER``
+ registry.
+ """
+ return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+
+
+def build_transformer_layer_sequence(cfg, default_args=None):
+ """Builder for transformer encoder and transformer decoder.
+
+ Args:
+ cfg: Config passed unchanged to ``build_from_cfg``.
+ default_args: Passed unchanged to ``build_from_cfg``. Default: None.
+
+ Returns:
+ The result of ``build_from_cfg`` for the
+ ``TRANSFORMER_LAYER_SEQUENCE`` registry.
+ """
+ return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+
+
+@ATTENTION.register_module()
+class MultiheadAttention(BaseModule):
+    """A wrapper for ``torch.nn.MultiheadAttention``.
+
+    This module implements MultiheadAttention with identity connection,
+    and positional encoding is also passed as input.
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default: 0.0.
+        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+            Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut. A falsy value selects
+            ``nn.Identity``.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): When it is True, Key, Query and Value are shape of
+            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+            Default to False.
+        kwargs: Remaining keyword arguments are forwarded to
+            ``nn.MultiheadAttention``. The deprecated key ``dropout`` is
+            consumed here and overrides ``attn_drop`` and the
+            ``drop_prob`` of ``dropout_layer``.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+        super(MultiheadAttention, self).__init__(init_cfg)
+        if 'dropout' in kwargs:
+            warnings.warn('The arguments `dropout` in MultiheadAttention '
+                          'has been deprecated, now you can separately '
+                          'set `attn_drop`(float), proj_drop(float), '
+                          'and `dropout_layer`(dict) ')
+            attn_drop = kwargs['dropout']
+            # Copy before writing: ``dropout_layer`` may be the default dict
+            # shared by every call, or a config owned by the caller.
+            dropout_layer = copy.deepcopy(dropout_layer)
+            dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.batch_first = batch_first
+
+        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+                                          **kwargs)
+
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else nn.Identity()
+
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiheadAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_pos=None,
+                attn_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `MultiheadAttention`.
+
+        **kwargs allow passing a more general data flow when combining
+        with other operations in `transformerlayer`.
+
+        Args:
+            query (Tensor): The input query with shape [num_queries, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims].
+                If None, the ``query`` will be used. Defaults to None.
+            value (Tensor): The value tensor with same shape as `key`.
+                Same in `nn.MultiheadAttention.forward`. Defaults to None.
+                If None, the `key` will be used.
+            identity (Tensor): This tensor, with the same shape as `query`,
+                will be used for the identity link.
+                If None, `query` will be used. Defaults to None.
+            query_pos (Tensor): The positional encoding for query, with
+                the same shape as `query`. If not None, it will
+                be added to `query` before forward function.
+                Defaults to None.
+            key_pos (Tensor): The positional encoding for `key`, with the
+                same shape as `key`. If not None, it will
+                be added to `key` before forward function. If None, and
+                `query_pos` has the same shape as `key`, then `query_pos`
+                will be used for `key_pos`. Defaults to None.
+            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                num_keys]. Same in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+                Defaults to None.
+
+        Returns:
+            Tensor: forwarded results with shape
+            [num_queries, bs, embed_dims]
+            if self.batch_first is False, else
+            [bs, num_queries, embed_dims].
+        """
+
+        # Fill in the defaults. ``value`` and ``identity`` are taken before
+        # any positional encoding is added to ``key`` / ``query``.
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+        if identity is None:
+            identity = query
+        if key_pos is None:
+            if query_pos is not None:
+                # use query_pos if key_pos is not available
+                if query_pos.shape == key.shape:
+                    key_pos = query_pos
+                else:
+                    warnings.warn(f'position encoding of key is '
+                                  f'missing in {self.__class__.__name__}.')
+        if query_pos is not None:
+            query = query + query_pos
+        if key_pos is not None:
+            key = key + key_pos
+
+        # Because the dataflow('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), We should adjust the shape of dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query ,batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            query = query.transpose(0, 1)
+            key = key.transpose(0, 1)
+            value = value.transpose(0, 1)
+
+        # ``self.attn`` returns a tuple; only its first element is used.
+        out = self.attn(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask)[0]
+
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+
+@FEEDFORWARD_NETWORK.register_module()
+class FFN(BaseModule):
+    """Feed-forward network (FFN) with an optional identity connection.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`. Defaults: 256.
+        feedforward_channels (int): The hidden dimension of FFNs.
+            Defaults: 1024.
+        num_fcs (int, optional): The number of fully-connected layers in
+            FFNs. Must be at least 2. Default: 2.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='ReLU', inplace=True)
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        add_identity (bool, optional): Whether to add the
+            identity connection. Default: `True`.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut. A falsy value selects
+            ``torch.nn.Identity``.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    @deprecated_api_warning(
+        {
+            'dropout': 'ffn_drop',
+            'add_residual': 'add_identity'
+        },
+        cls_name='FFN')
+    def __init__(self,
+                 embed_dims=256,
+                 feedforward_channels=1024,
+                 num_fcs=2,
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 ffn_drop=0.,
+                 dropout_layer=None,
+                 add_identity=True,
+                 init_cfg=None,
+                 **kwargs):
+        super().__init__(init_cfg)
+        assert num_fcs >= 2, 'num_fcs should be no less ' \
+            f'than 2. got {num_fcs}.'
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.num_fcs = num_fcs
+        self.act_cfg = act_cfg
+        # A single activation module instance is reused by every hidden block.
+        self.activate = build_activation_layer(act_cfg)
+
+        stages = []
+        for depth in range(num_fcs - 1):
+            # Only the first hidden block reads the embedding width.
+            width_in = embed_dims if depth == 0 else feedforward_channels
+            hidden_block = Sequential(
+                Linear(width_in, feedforward_channels), self.activate,
+                nn.Dropout(ffn_drop))
+            stages.append(hidden_block)
+        # Project back to the embedding width, followed by dropout.
+        stages.append(Linear(feedforward_channels, embed_dims))
+        stages.append(nn.Dropout(ffn_drop))
+        self.layers = Sequential(*stages)
+
+        if dropout_layer:
+            self.dropout_layer = build_dropout(dropout_layer)
+        else:
+            self.dropout_layer = torch.nn.Identity()
+        self.add_identity = add_identity
+
+    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+    def forward(self, x, identity=None):
+        """Forward function for `FFN`.
+
+        When ``add_identity`` is set, ``identity`` (or ``x`` if ``identity``
+        is None) is added to the output tensor.
+        """
+        dropped = self.dropout_layer(self.layers(x))
+        if self.add_identity:
+            shortcut = x if identity is None else identity
+            return shortcut + dropped
+        return dropped
+
+
+@TRANSFORMER_LAYER.register_module()
+class BaseTransformerLayer(BaseModule):
+    """Base `TransformerLayer` for vision transformer.
+
+    It can be built from `mmcv.ConfigDict` and support more flexible
+    customization, for example, using any number of `FFN or LN ` and
+    use different kinds of `attention` by specifying a list of `ConfigDict`
+    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+    when you specifying `norm` as the first element of `operation_order`.
+    More details about the `prenorm`: the paper "On Layer Normalization in
+    the Transformer Architecture".
+
+    Args:
+        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+            Configs for `self_attention` or `cross_attention` modules,
+            The order of the configs in the list should be consistent with
+            corresponding attentions in operation_order.
+            If it is a dict, all of the attention modules in operation_order
+            will be built with this config. Default: None.
+        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
+            Configs for FFN, The order of the configs in the list should be
+            consistent with corresponding ffn in operation_order.
+            If it is a dict, all of the ffn modules in operation_order
+            will be built with this config.
+        operation_order (tuple[str]): The execution order of operation
+            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+            Support `prenorm` when you specifying first element as `norm`.
+            Default:None.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): Key, Query and Value are shape
+            of (batch, n, embed_dim)
+            or (n, batch, embed_dim). Default to False.
+    """
+
+    def __init__(self,
+                 attn_cfgs=None,
+                 ffn_cfgs=dict(
+                     type='FFN',
+                     embed_dims=256,
+                     feedforward_channels=1024,
+                     num_fcs=2,
+                     ffn_drop=0.,
+                     act_cfg=dict(type='ReLU', inplace=True),
+                 ),
+                 operation_order=None,
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+
+        deprecated_args = dict(
+            feedforward_channels='feedforward_channels',
+            ffn_dropout='ffn_drop',
+            ffn_num_fcs='num_fcs')
+        if any(ori_name in kwargs for ori_name in deprecated_args):
+            # Copy before writing: ``ffn_cfgs`` may be the default dict
+            # shared by every call, or a config owned by the caller.
+            ffn_cfgs = copy.deepcopy(ffn_cfgs)
+        for ori_name, new_name in deprecated_args.items():
+            if ori_name in kwargs:
+                warnings.warn(
+                    f'The arguments `{ori_name}` in BaseTransformerLayer '
+                    f'has been deprecated, now you should set `{new_name}` '
+                    f'and other FFN related arguments '
+                    f'to a dict named `ffn_cfgs`. ')
+                ffn_cfgs[new_name] = kwargs[ori_name]
+
+        super(BaseTransformerLayer, self).__init__(init_cfg)
+
+        self.batch_first = batch_first
+
+        # Every entry of ``operation_order`` must be one of these names.
+        valid_operations = ['self_attn', 'norm', 'ffn', 'cross_attn']
+        assert set(operation_order).issubset(valid_operations), \
+            f'The operation_order of {self.__class__.__name__} should ' \
+            f'only contain operation types from {valid_operations}'
+
+        num_attn = operation_order.count('self_attn') + operation_order.count(
+            'cross_attn')
+        if isinstance(attn_cfgs, dict):
+            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+        else:
+            assert num_attn == len(attn_cfgs), f'The length ' \
+                f'of attn_cfgs {len(attn_cfgs)} is ' \
+                f'not consistent with the number of attention {num_attn} ' \
+                f'in operation_order {operation_order}.'
+
+        self.num_attn = num_attn
+        self.operation_order = operation_order
+        self.norm_cfg = norm_cfg
+        self.pre_norm = operation_order[0] == 'norm'
+        self.attentions = ModuleList()
+
+        index = 0
+        for operation_name in operation_order:
+            if operation_name in ['self_attn', 'cross_attn']:
+                # ``batch_first`` of each attention must agree with the layer.
+                if 'batch_first' in attn_cfgs[index]:
+                    assert self.batch_first == attn_cfgs[index]['batch_first']
+                else:
+                    attn_cfgs[index]['batch_first'] = self.batch_first
+                attention = build_attention(attn_cfgs[index])
+                # Some custom attentions used as `self_attn`
+                # or `cross_attn` can have different behavior.
+                attention.operation_name = operation_name
+                self.attentions.append(attention)
+                index += 1
+
+        self.embed_dims = self.attentions[0].embed_dims
+
+        self.ffns = ModuleList()
+        num_ffns = operation_order.count('ffn')
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = ConfigDict(ffn_cfgs)
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+        assert len(ffn_cfgs) == num_ffns
+        for ffn_index in range(num_ffns):
+            if 'embed_dims' not in ffn_cfgs[ffn_index]:
+                # ``ffn_cfgs`` is a list here, so the missing key has to be
+                # written into the individual config, not into the list.
+                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
+            else:
+                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+            self.ffns.append(
+                build_feedforward_network(ffn_cfgs[ffn_index],
+                                          dict(type='FFN')))
+
+        self.norms = ModuleList()
+        num_norms = operation_order.count('norm')
+        for _ in range(num_norms):
+            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                query_pos=None,
+                key_pos=None,
+                attn_masks=None,
+                query_key_padding_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `TransformerDecoderLayer`.
+
+        **kwargs contains some specific arguments of attentions.
+
+        Args:
+            query (Tensor): The input query with shape
+                [num_queries, bs, embed_dims] if
+                self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims].
+            value (Tensor): The value tensor with same shape as `key`.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`.
+                Default: None.
+            attn_masks (List[Tensor] | Tensor | None): 2D Tensor used in
+                calculation of corresponding attention. The length of
+                it should equal to the number of `attention` in
+                `operation_order`. A single Tensor is copied for every
+                attention. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in `self_attn` layer.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+                shape [bs, num_keys]. Only used in `cross_attn` layer.
+                Default: None.
+
+        Returns:
+            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+        """
+
+        norm_index = 0
+        attn_index = 0
+        ffn_index = 0
+        identity = query
+        if attn_masks is None:
+            attn_masks = [None for _ in range(self.num_attn)]
+        elif isinstance(attn_masks, torch.Tensor):
+            attn_masks = [
+                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+            ]
+            warnings.warn(f'Use same attn_mask in all attentions in '
+                          f'{self.__class__.__name__} ')
+        else:
+            assert len(attn_masks) == self.num_attn, f'The length of ' \
+                f'attn_masks {len(attn_masks)} must be equal ' \
+                f'to the number of attention in ' \
+                f'operation_order {self.num_attn}'
+
+        # ``identity`` is only handed to attentions and FFNs in pre-norm
+        # mode; otherwise they receive None for it.
+        for layer in self.operation_order:
+            if layer == 'self_attn':
+                temp_key = temp_value = query
+                query = self.attentions[attn_index](
+                    query,
+                    temp_key,
+                    temp_value,
+                    identity if self.pre_norm else None,
+                    query_pos=query_pos,
+                    key_pos=query_pos,
+                    attn_mask=attn_masks[attn_index],
+                    key_padding_mask=query_key_padding_mask,
+                    **kwargs)
+                attn_index += 1
+                identity = query
+
+            elif layer == 'norm':
+                query = self.norms[norm_index](query)
+                norm_index += 1
+
+            elif layer == 'cross_attn':
+                query = self.attentions[attn_index](
+                    query,
+                    key,
+                    value,
+                    identity if self.pre_norm else None,
+                    query_pos=query_pos,
+                    key_pos=key_pos,
+                    attn_mask=attn_masks[attn_index],
+                    key_padding_mask=key_padding_mask,
+                    **kwargs)
+                attn_index += 1
+                identity = query
+
+            elif layer == 'ffn':
+                query = self.ffns[ffn_index](
+                    query, identity if self.pre_norm else None)
+                ffn_index += 1
+
+        return query
+
+
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class TransformerLayerSequence(BaseModule):
+    """Base class for TransformerEncoder and TransformerDecoder in vision
+    transformer.
+
+    Holds a stack of transformer layers and runs them one after another.
+    Different kinds of `transformer_layer` can be mixed by passing a list
+    of configs.
+
+    Args:
+        transformerlayers (list[obj:`mmcv.ConfigDict`] |
+            obj:`mmcv.ConfigDict`): Config of transformerlayer
+            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
+            it would be repeated `num_layer` times to a
+            list[`mmcv.ConfigDict`]. Default: None.
+        num_layers (int): The number of `TransformerLayer`. Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+        super().__init__(init_cfg)
+        if not isinstance(transformerlayers, dict):
+            assert isinstance(transformerlayers, list) and \
+                len(transformerlayers) == num_layers
+        else:
+            # A single config is deep-copied once per layer.
+            transformerlayers = [
+                copy.deepcopy(transformerlayers) for _ in range(num_layers)
+            ]
+        self.num_layers = num_layers
+        self.layers = ModuleList()
+        for layer_idx in range(num_layers):
+            layer_cfg = transformerlayers[layer_idx]
+            self.layers.append(build_transformer_layer(layer_cfg))
+        first_layer = self.layers[0]
+        self.embed_dims = first_layer.embed_dims
+        self.pre_norm = first_layer.pre_norm
+
+    def forward(self,
+                query,
+                key,
+                value,
+                query_pos=None,
+                key_pos=None,
+                attn_masks=None,
+                query_key_padding_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `TransformerCoder`.
+
+        Args:
+            query (Tensor): Input query with shape
+                `(num_queries, bs, embed_dims)`.
+            key (Tensor): The key tensor with shape
+                `(num_keys, bs, embed_dims)`.
+            value (Tensor): The value tensor with shape
+                `(num_keys, bs, embed_dims)`.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`.
+                Default: None.
+            attn_masks (List[Tensor], optional): Each element is 2D Tensor
+                which is used in calculation of corresponding attention in
+                operation_order. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in self-attention
+                Default: None.
+            key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_keys]. Default: None.
+
+        Returns:
+            Tensor: results with shape [num_queries, bs, embed_dims].
+        """
+        # Every layer receives the same key/value and auxiliary arguments;
+        # only ``query`` is threaded from one layer to the next.
+        shared_kwargs = dict(
+            query_pos=query_pos,
+            key_pos=key_pos,
+            attn_masks=attn_masks,
+            query_key_padding_mask=query_key_padding_mask,
+            key_padding_mask=key_padding_mask,
+            **kwargs)
+        for block in self.layers:
+            query = block(query, key, value, **shared_kwargs)
+        return query
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+class PixelShufflePack(nn.Module):
+ """Pixel Shuffle upsample layer.
+
+ This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+ achieve a simple upsampling with pixel shuffle.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Upsample ratio.
+ upsample_kernel (int): Kernel size of the conv layer to expand the
+ channels.
+ """
+
+ def __init__(self, in_channels, out_channels, scale_factor,
+ upsample_kernel):
+ super(PixelShufflePack, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+ self.upsample_kernel = upsample_kernel
+ self.upsample_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels * scale_factor * scale_factor,
+ self.upsample_kernel,
+ padding=(self.upsample_kernel - 1) // 2)
+ self.init_weights()
+
+ def init_weights(self):
+ xavier_init(self.upsample_conv, distribution='uniform')
+
+ def forward(self, x):
+ x = self.upsample_conv(x)
+ x = F.pixel_shuffle(x, self.scale_factor)
+ return x
+
+
+def build_upsample_layer(cfg, *args, **kwargs):
+ """Build upsample layer.
+
+ Args:
+ cfg (dict): The upsample layer config, which should contain:
+
+ - type (str): Layer type.
+ - scale_factor (int): Upsample ratio, which is not applicable to
+ deconv.
+ - layer args: Args needed to instantiate an upsample layer.
+ args (argument list): Arguments passed to the ``__init__``
+ method of the corresponding upsample layer.
+ kwargs (keyword arguments): Keyword arguments passed to the
+ ``__init__`` method of the corresponding upsample layer.
+
+ Returns:
+ nn.Module: Created upsample layer.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ raise KeyError(
+ f'the cfg dict must contain the key "type", but got {cfg}')
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop('type')
+ if layer_type not in UPSAMPLE_LAYERS:
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
+ else:
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+ if upsample is nn.Upsample:
+ cfg_['mode'] = layer_type
+ layer = upsample(*args, **kwargs, **cfg_)
+ return layer
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
+
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+if torch.__version__ == 'parrots':
+ TORCH_VERSION = torch.__version__
+else:
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
+ # for comparison
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def obsolete_torch_version(torch_version, version_threshold):
+ return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+class NewEmptyTensorOp(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, x, new_shape):
+ ctx.shape = x.shape
+ return x.new_empty(new_shape)
+
+ @staticmethod
+ def backward(ctx, grad):
+ shape = ctx.shape
+ return NewEmptyTensorOp.apply(grad, shape), None
+
+
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride, self.dilation):
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv')
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+ def forward(self, x):
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+ out_shape = [x.shape[0], self.out_channels]
+ for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+ self.padding, self.stride,
+ self.dilation, self.output_padding):
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool2d(nn.MaxPool2d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+ _pair(self.padding), _pair(self.stride),
+ _pair(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class MaxPool3d(nn.MaxPool3d):
+
+ def forward(self, x):
+ # PyTorch 1.9 does not support empty tensor inference yet
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+ out_shape = list(x.shape[:2])
+ for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+ _triple(self.padding),
+ _triple(self.stride),
+ _triple(self.dilation)):
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
+ out_shape.append(o)
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ return empty
+
+ return super().forward(x)
+
+
+class Linear(torch.nn.Linear):
+
+ def forward(self, x):
+ # empty tensor forward of Linear layer is supported in PyTorch 1.6
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+ out_shape = [x.shape[0], self.out_features]
+ empty = NewEmptyTensorOp.apply(x, out_shape)
+ if self.training:
+ # produce dummy gradient to avoid DDP warning.
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+ return empty + dummy
+ else:
+ return empty
+
+ return super().forward(x)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/builder.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+
+
+def build_model_from_cfg(cfg, registry, default_args=None):
+ """Build a PyTorch model from config dict(s). Different from
+ ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+ Args:
+ cfg (dict, list[dict]): The config of modules, it is either a config
+ dict or a list of config dicts. If cfg is a list,
+ the built modules will be wrapped with ``nn.Sequential``.
+ registry (:obj:`Registry`): A registry the module belongs to.
+ default_args (dict, optional): Default arguments to build the module.
+ Defaults to None.
+
+ Returns:
+ nn.Module: A built nn module.
+ """
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+MODELS = Registry('model', build_func=build_model_from_cfg)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py
@@ -0,0 +1,316 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from .utils import constant_init, kaiming_init
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ """3x3 convolution with padding."""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False):
+ super(BasicBlock, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ assert not with_cp
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False):
+ """Bottleneck block.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ if style == 'pytorch':
+ conv1_stride = 1
+ conv2_stride = stride
+ else:
+ conv1_stride = stride
+ conv2_stride = 1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def make_res_layer(block,
+ inplanes,
+ planes,
+ blocks,
+ stride=1,
+ dilation=1,
+ style='pytorch',
+ with_cp=False):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ dilation,
+ downsample,
+ style=style,
+ with_cp=with_cp))
+ inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+ return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ with_cp=False):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ stage_blocks = stage_blocks[:num_stages]
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(
+ 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = 64 * 2**i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp)
+ self.inplanes = planes * block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ from ..runner import load_checkpoint
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(ResNet, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, f'layer{i}')
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+ KaimingInit, NormalInit, PretrainedInit,
+ TruncNormalInit, UniformInit, XavierInit,
+ bias_init_with_prob, caffe2_xavier_init,
+ constant_init, initialize, kaiming_init, normal_init,
+ trunc_normal_init, uniform_init, xavier_init)
+
+__all__ = [
+ 'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+ 'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+ 'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+ 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+ 'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..104240bfa524af727782ceb781147c5815529ee6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py
@@ -0,0 +1,599 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import annotator.mmpkg.mmcv as mmcv
+
+
+def get_model_complexity_info(model,
+ input_shape,
+ print_per_layer_stat=True,
+ as_strings=True,
+ input_constructor=None,
+ flush=False,
+ ost=sys.stdout):
+ """Get complexity information of a model.
+
+ This method can calculate FLOPs and parameter counts of a model with
+ corresponding input shape. It can also print complexity information for
+ each layer in a model.
+
+ Supported layers are listed as below:
+ - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+ - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
+ ``nn.ReLU6``.
+ - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+ ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+ ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+ ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+ ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+ - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+ ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+ ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+ - Linear: ``nn.Linear``.
+ - Deconvolution: ``nn.ConvTranspose2d``.
+ - Upsample: ``nn.Upsample``.
+
+ Args:
+ model (nn.Module): The model for complexity calculation.
+ input_shape (tuple): Input shape used for calculation.
+ print_per_layer_stat (bool): Whether to print complexity information
+ for each layer in a model. Default: True.
+ as_strings (bool): Output FLOPs and params counts in a string form.
+ Default: True.
+ input_constructor (None | callable): If specified, it is called with
+ ``input_shape`` and returns a dict of keyword inputs. Otherwise, an
+ uninitialized ``(1, *input_shape)`` tensor is used. Default: None.
+ flush (bool): same as that in :func:`print`. Default: False.
+ ost (stream): same as ``file`` param in :func:`print`.
+ Default: sys.stdout.
+
+ Returns:
+ tuple[float | str]: If ``as_strings`` is set to True, it will return
+ FLOPs and parameter counts in a string format. otherwise, it will
+ return those in a float number format.
+ """
+ assert type(input_shape) is tuple
+ assert len(input_shape) >= 1
+ assert isinstance(model, nn.Module)
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval()
+ flops_model.start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_shape)
+ _ = flops_model(**input)
+ else:
+ try:
+ batch = torch.ones(()).new_empty(
+ (1, *input_shape),
+ dtype=next(flops_model.parameters()).dtype,
+ device=next(flops_model.parameters()).device)
+ except StopIteration:
+ # Avoid StopIteration for models which have no parameters,
+ # like `nn.ReLU()`, `nn.AvgPool2d`, etc.
+ batch = torch.ones(()).new_empty((1, *input_shape))
+
+ _ = flops_model(batch)
+
+ flops_count, params_count = flops_model.compute_average_flops_cost()
+ if print_per_layer_stat:
+ print_model_with_flops(
+ flops_model, flops_count, params_count, ost=ost, flush=flush)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
+
+def flops_to_string(flops, units='GFLOPs', precision=2):
+ """Convert FLOPs number into a string.
+
+ Note that here we count one multiply-add as one FLOP.
+
+ Args:
+ flops (float): FLOPs number to be converted.
+ units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+ 'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+ choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted FLOPs number with units.
+
+ Examples:
+ >>> flops_to_string(1e9)
+ '1.0 GFLOPs'
+ >>> flops_to_string(2e5, 'MFLOPs')
+ '0.2 MFLOPs'
+ >>> flops_to_string(3e-9, None)
+ '3e-09 FLOPs'
+ """
+ if units is None:
+ if flops // 10**9 > 0:
+ return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+ elif flops // 10**6 > 0:
+ return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+ elif flops // 10**3 > 0:
+ return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+ else:
+ return str(flops) + ' FLOPs'
+ else:
+ if units == 'GFLOPs':
+ return str(round(flops / 10.**9, precision)) + ' ' + units
+ elif units == 'MFLOPs':
+ return str(round(flops / 10.**6, precision)) + ' ' + units
+ elif units == 'KFLOPs':
+ return str(round(flops / 10.**3, precision)) + ' ' + units
+ else:
+ return str(flops) + ' FLOPs'
+
+
+def params_to_string(num_params, units=None, precision=2):
+ """Convert parameter number into a string.
+
+ Args:
+ num_params (float): Parameter number to be converted.
+ units (str | None): Converted parameter units. Options are None, 'M',
+ 'K' and ''. If set to None, it will automatically choose the most
+ suitable unit for the parameter number. Default: None.
+ precision (int): Digit number after the decimal point. Default: 2.
+
+ Returns:
+ str: The converted parameter number with units.
+
+ Examples:
+ >>> params_to_string(1e9)
+ '1000.0 M'
+ >>> params_to_string(2e5)
+ '200.0 k'
+ >>> params_to_string(3e-9)
+ '3e-09'
+ """
+ if units is None:
+ if num_params // 10**6 > 0:
+ return str(round(num_params / 10**6, precision)) + ' M'
+ elif num_params // 10**3:
+ return str(round(num_params / 10**3, precision)) + ' k'
+ else:
+ return str(num_params)
+ else:
+ if units == 'M':
+ return str(round(num_params / 10.**6, precision)) + ' ' + units
+ elif units == 'K':
+ return str(round(num_params / 10.**3, precision)) + ' ' + units
+ else:
+ return str(num_params)
+
+
+def print_model_with_flops(model,
+ total_flops,
+ total_params,
+ units='GFLOPs',
+ precision=3,
+ ost=sys.stdout,
+ flush=False):
+ """Print a model with FLOPs for each layer.
+
+ Args:
+ model (nn.Module): The model to be printed.
+ total_flops (float): Total FLOPs of the model.
+ total_params (float): Total parameter counts of the model.
+ units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+ precision (int): Digit number after the decimal point. Default: 3.
+ ost (stream): same as `file` param in :func:`print`.
+ Default: sys.stdout.
+ flush (bool): same as that in :func:`print`. Default: False.
+
+ Example:
+ >>> class ExampleModel(nn.Module):
+
+ >>> def __init__(self):
+ >>> super().__init__()
+ >>> self.conv1 = nn.Conv2d(3, 8, 3)
+ >>> self.conv2 = nn.Conv2d(8, 256, 3)
+ >>> self.conv3 = nn.Conv2d(256, 8, 3)
+ >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+ >>> self.flatten = nn.Flatten()
+ >>> self.fc = nn.Linear(8, 1)
+
+ >>> def forward(self, x):
+ >>> x = self.conv1(x)
+ >>> x = self.conv2(x)
+ >>> x = self.conv3(x)
+ >>> x = self.avg_pool(x)
+ >>> x = self.flatten(x)
+ >>> x = self.fc(x)
+ >>> return x
+
+ >>> model = ExampleModel()
+ >>> x = (3, 16, 16)
+ to print the complexity information state for each layer, you can use
+ >>> get_model_complexity_info(model, x)
+ or directly use
+ >>> print_model_with_flops(model, 4579784.0, 37361)
+ ExampleModel(
+ 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+ (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501
+ (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+ (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+ (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+ (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+ (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+ )
+ """
+
+ def accumulate_params(self):
+ if is_supported_instance(self):
+ return self.__params__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_params()
+ return sum
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_num_params = self.accumulate_params()
+ accumulated_flops_cost = self.accumulate_flops()
+ return ', '.join([
+ params_to_string(
+ accumulated_num_params, units='M', precision=precision),
+ '{:.3%} Params'.format(accumulated_num_params / total_params),
+ flops_to_string(
+ accumulated_flops_cost, units=units, precision=precision),
+ '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+ self.original_extra_repr()
+ ])
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ m.accumulate_params = accumulate_params.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, 'original_extra_repr'):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, 'accumulate_flops'):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model, file=ost, flush=flush)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+ """Calculate parameter number of a model.
+
+ Args:
+ model (nn.Module): The model for parameter number calculation.
+
+ Returns:
+ int: Number of parameters with ``requires_grad=True`` in the model.
+ """
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return num_params
+
+
+def add_flops_counting_methods(net_main_module):
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ net_main_module.start_flops_count = start_flops_count.__get__(
+ net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__(
+ net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__(
+ net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( # noqa: E501
+ net_main_module)
+
+ net_main_module.reset_flops_count()
+
+ return net_main_module
+
+
+def compute_average_flops_cost(self):
+ """Compute average FLOPs cost.
+
+ A method to compute average FLOPs cost, which will be available after
+ `add_flops_counting_methods()` is called on a desired net object.
+
+ Returns:
+ tuple: Mean FLOPs consumption per image and the parameter count.
+ """
+ batches_count = self.__batch_counter__
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+ params_sum = get_model_parameters_number(self)
+ return flops_sum / batches_count, params_sum
+
+
+def start_flops_count(self):
+ """Activate the computation of mean flops consumption per image.
+
+ A method to activate the computation of mean flops consumption per image,
+ which will be available after ``add_flops_counting_methods()`` is called on
+ a desired net object. It should be called before running the network.
+ """
+ add_batch_counter_hook_function(self)
+
+ def add_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, '__flops_handle__'):
+ return
+
+ else:
+ handle = module.register_forward_hook(
+ get_modules_mapping()[type(module)])
+
+ module.__flops_handle__ = handle
+
+ self.apply(partial(add_flops_counter_hook_function))
+
+
+def stop_flops_count(self):
+    """Pause FLOPs accounting by detaching the counting hooks.
+
+    Available as a method after ``add_flops_counting_methods()`` has been
+    called on the net object. Removes the batch-counter hook from ``self``
+    and the per-layer FLOPs hooks from every supported sub-module.
+    """
+    remove_batch_counter_hook_function(self)
+    self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+    """Zero the statistics gathered so far.
+
+    Available as a method after `add_flops_counting_methods()` has been
+    called on the net object.
+    """
+    add_batch_counter_variables_or_reset(self)
+    self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module, input, output):
+    module.__flops__ += 0  # placeholder hook: contributes no FLOPs
+
+
+def upsample_flops_counter_hook(module, input, output):
+    # NOTE(review): output[0] looks like the first sample only -- confirm.
+    sample = output[0]
+    elements = sample.shape[0]
+    for extent in sample.shape[1:]:
+        elements *= extent
+    module.__flops__ += int(elements)
+
+
+def relu_flops_counter_hook(module, input, output):
+    # One operation is counted per element of the activation output.
+    module.__flops__ += int(output.numel())
+
+
+def linear_flops_counter_hook(module, input, output):
+    # FLOPs = (elements of the first input) x (size of output's last dim).
+    # PyTorch has already validated the shapes, so none are re-checked.
+    first_input, out_features = input[0], output.shape[-1]
+    module.__flops__ += int(np.prod(first_input.shape) * out_features)
+
+
+def pool_flops_counter_hook(module, input, output):
+    # One operation is counted per element of the first positional input.
+    module.__flops__ += int(np.prod(input[0].shape))
+
+
+def norm_flops_counter_hook(module, input, output):
+    """Count one op per input element, doubled when the layer is affine."""
+    has_affine = (getattr(module, 'affine', False)
+                  or getattr(module, 'elementwise_affine', False))
+    # Either flag marks a learnable scale-and-shift, which is counted as
+    # one extra pass over the input -- hence the factor of two.
+    scale = 2 if has_affine else 1
+    module.__flops__ += int(np.prod(input[0].shape) * scale)
+
+
+def deconv_flops_counter_hook(conv_module, input, output):
+    """Accumulate the FLOPs of a 2D transposed convolution.
+
+    The multiply cost is counted per input position; the bias adds one
+    operation per output element.
+    """
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+
+    batch_size = input.shape[0]
+    input_height, input_width = input.shape[2:]
+
+    kernel_height, kernel_width = conv_module.kernel_size
+    filters_per_channel = conv_module.out_channels // conv_module.groups
+    conv_per_position_flops = (kernel_height * kernel_width *
+                               conv_module.in_channels * filters_per_channel)
+
+    active_elements_count = batch_size * input_height * input_width
+    overall_flops = conv_per_position_flops * active_elements_count
+    if conv_module.bias is not None:
+        output_height, output_width = output.shape[2:]
+        # Fixed: the output area is height x width (was height x height).
+        overall_flops += (conv_module.out_channels * batch_size *
+                          output_height * output_width)
+    conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+    """Accumulate the FLOPs of an N-d convolution onto ``conv_module``.
+
+    Cost per output position is ``prod(kernel_size) * in_channels *
+    (out_channels // groups)``; a bias adds ``out_channels`` more per
+    position. Positions are counted over the whole batch.
+    """
+    # Only the first positional input is relevant for the batch size.
+    first_input = input[0]
+    num_samples = first_input.shape[0]
+
+    # Assumes outputs are laid out as (batch, channel, *spatial).
+    spatial_extent = int(np.prod(list(output.shape[2:])))
+    kernel_volume = int(np.prod(list(conv_module.kernel_size)))
+
+    out_channels = conv_module.out_channels
+    per_group_filters = out_channels // conv_module.groups
+
+    flops_per_position = (
+        kernel_volume * conv_module.in_channels * per_group_filters)
+    num_positions = num_samples * spatial_extent
+
+    total = flops_per_position * num_positions
+    if conv_module.bias is not None:
+        # One addition per output channel at every output position.
+        total = total + out_channels * num_positions
+
+    # Counters are kept as plain Python ints.
+    conv_module.__flops__ += int(total)
+
+
+def batch_counter_hook(module, input, output):
+    """Add the batch size of this forward call to ``__batch_counter__``.
+
+    The size is ``len()`` of the first positional input; with no
+    positional inputs a warning is printed and a size of 1 is assumed.
+    """
+    if len(input) == 0:
+        print('Warning! No positional inputs found for a module, '
+              'assuming batch size is 1.')
+    increment = len(input[0]) if len(input) > 0 else 1
+    module.__batch_counter__ += increment
+
+
+def add_batch_counter_variables_or_reset(module):
+    # Creates the counter on first use and zeroes it on later calls.
+    module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+    """Register ``batch_counter_hook`` once; repeated calls are no-ops."""
+    if not hasattr(module, '__batch_counter_handle__'):
+        # Keep the handle so the hook can be removed again later.
+        module.__batch_counter_handle__ = module.register_forward_hook(
+            batch_counter_hook)
+
+
+def remove_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):  # no-op when not hooked
+        module.__batch_counter_handle__.remove()
+        del module.__batch_counter_handle__  # forget the stale handle
+
+
+def add_flops_counter_variable_or_reset(module):
+    if is_supported_instance(module):  # other module types are skipped
+        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+            print('Warning: variables __flops__ or __params__ are already '
+                  f'defined for the module {type(module).__name__} '
+                  'ptflops can affect your code!')  # space after "module"
+        module.__flops__ = 0  # (re)initialise both counters
+        module.__params__ = get_model_parameters_number(module)
+
+
+def is_supported_instance(module):
+    # Exact type match: subclasses of mapped layers are not supported
+    # unless they are listed in ``get_modules_mapping()`` themselves.
+    return type(module) in get_modules_mapping()
+
+
+def remove_flops_counter_hook_function(module):
+    # Detach the FLOPs hook, if any, and forget its handle.
+    if is_supported_instance(module) and hasattr(module, '__flops_handle__'):
+        module.__flops_handle__.remove()
+        del module.__flops_handle__
+
+
+def get_modules_mapping():
+    return {  # exact layer type -> forward hook that adds to ``__flops__``
+        # convolutions
+        nn.Conv1d: conv_flops_counter_hook,
+        nn.Conv2d: conv_flops_counter_hook,
+        mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+        nn.Conv3d: conv_flops_counter_hook,
+        mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+        # activations (one op per output element)
+        nn.ReLU: relu_flops_counter_hook,
+        nn.PReLU: relu_flops_counter_hook,
+        nn.ELU: relu_flops_counter_hook,
+        nn.LeakyReLU: relu_flops_counter_hook,
+        nn.ReLU6: relu_flops_counter_hook,
+        # poolings (one op per input element)
+        nn.MaxPool1d: pool_flops_counter_hook,
+        nn.AvgPool1d: pool_flops_counter_hook,
+        nn.AvgPool2d: pool_flops_counter_hook,
+        nn.MaxPool2d: pool_flops_counter_hook,
+        mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+        nn.MaxPool3d: pool_flops_counter_hook,
+        mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+        nn.AvgPool3d: pool_flops_counter_hook,
+        nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+        nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+        nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+        nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+        nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+        nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+        # normalizations
+        nn.BatchNorm1d: norm_flops_counter_hook,
+        nn.BatchNorm2d: norm_flops_counter_hook,
+        nn.BatchNorm3d: norm_flops_counter_hook,
+        nn.GroupNorm: norm_flops_counter_hook,
+        nn.InstanceNorm1d: norm_flops_counter_hook,
+        nn.InstanceNorm2d: norm_flops_counter_hook,
+        nn.InstanceNorm3d: norm_flops_counter_hook,
+        nn.LayerNorm: norm_flops_counter_hook,
+        # fully connected (FC) layers
+        nn.Linear: linear_flops_counter_hook,
+        mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
+        # Upscale
+        nn.Upsample: upsample_flops_counter_hook,
+        # Deconvolution (2D transposed convolution only)
+        nn.ConvTranspose2d: deconv_flops_counter_hook,
+        mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+    }
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+def _fuse_conv_bn(conv, bn):
+    """Fold the statistics and affine terms of ``bn`` into ``conv``.
+
+    Args:
+        conv (nn.Module): Conv to be fused; modified in place.
+        bn (nn.Module): BN whose running stats and weight/bias are used.
+
+    Returns:
+        nn.Module: The same ``conv`` object with new weight and bias.
+    """
+    if conv.bias is not None:
+        old_bias = conv.bias
+    else:
+        old_bias = torch.zeros_like(bn.running_mean)
+    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+    per_channel_scale = scale.reshape([conv.out_channels, 1, 1, 1])
+    conv.weight = nn.Parameter(conv.weight * per_channel_scale)
+    conv.bias = nn.Parameter((old_bias - bn.running_mean) * scale + bn.bias)
+    return conv
+
+
+def fuse_conv_bn(module):
+    """Recursively fuse conv and bn in a module.
+
+    During inference, the functionality of batch norm layers is turned off
+    and only the per-channel mean and var are used, which exposes the
+    chance to fuse it with the preceding conv layers to save computations and
+    simplify network structures.
+
+    Args:
+        module (nn.Module): Module to be fused.
+
+    Returns:
+        nn.Module: Fused module.
+    """
+    pending_conv = None
+    pending_name = None
+    bn_types = (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)
+
+    for child_name, child in module.named_children():
+        if isinstance(child, bn_types):
+            # Only a BN that comes after a remembered Conv2d is fused.
+            if pending_conv is not None:
+                module._modules[pending_name] = _fuse_conv_bn(
+                    pending_conv, child)
+                # To reduce changes, the BN becomes an Identity in place.
+                module._modules[child_name] = nn.Identity()
+                pending_conv = None
+        elif isinstance(child, nn.Conv2d):
+            pending_conv, pending_name = child, child_name
+        else:
+            # Any other child is searched recursively for conv/bn pairs.
+            fuse_conv_bn(child)
+    return module
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0dbcb1b167ea0df690c0f47fe0217a3454b5d59
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py
@@ -0,0 +1,59 @@
+import torch
+
+import annotator.mmpkg.mmcv as mmcv
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
+    is `_check_input_dim` that is designed for tensor sanity checks.
+    The check has been bypassed in this class for the convenience of converting
+    SyncBatchNorm.
+    """
+
+    def _check_input_dim(self, input):
+        return  # intentionally a no-op: the input is never rejected here
+
+
+def revert_sync_batchnorm(module):
+    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to
+    `_BatchNormXd` layers.
+
+    Adapted from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+    Returns:
+        module_output: The converted module with `_BatchNormXd` layers.
+    """
+    module_output = module
+    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+    if hasattr(mmcv, 'ops'):
+        module_checklist.append(mmcv.ops.SyncBatchNorm)
+    if isinstance(module, tuple(module_checklist)):
+        module_output = _BatchNormXd(module.num_features, module.eps,
+                                     module.momentum, module.affine,
+                                     module.track_running_stats)
+        if module.affine:
+            # no_grad() may not be needed here but
+            # just to be consistent with `convert_sync_batchnorm()`
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        module_output.training = module.training
+        # qconfig exists in quantized models; carry it over when present
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..096d0ddcccbec84675f0771cb546d0fa003417e7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py
@@ -0,0 +1,684 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+
+
+def update_init_info(module, init_info):
+    """Record ``init_info`` for each parameter whose mean value changed
+    since the ``tmp_mean_value`` cached in ``_params_init_info``.
+
+    Args:
+        module (obj:`nn.Module`): The module of PyTorch with a user-defined
+            attribute `_params_init_info` which records the initialization
+            information.
+        init_info (str): The string that describes the initialization.
+    """
+    assert hasattr(
+        module,
+        '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+    for name, param in module.named_parameters():
+
+        assert param in module._params_init_info, (
+            f'Find a new :obj:`Parameter` '
+            f'named `{name}` during executing the '
+            f'`init_weights` of '
+            f'`{module.__class__.__name__}`. '
+            f'Please do not add or '
+            f'replace parameters during executing '
+            f'the `init_weights`. ')
+
+        # A different mean means the parameter was changed while the
+        # `init_weights` of the module was executing.
+        mean_value = param.data.mean()
+        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+            module._params_init_info[param]['init_info'] = init_info
+            module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module, val, bias=0):
+    """Fill ``module.weight`` with ``val`` and ``module.bias`` with ``bias``."""
+    for attr, fill in (('weight', val), ('bias', bias)):
+        if getattr(module, attr, None) is not None:
+            nn.init.constant_(getattr(module, attr), fill)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+    """Xavier-initialise the weight (if any) and fill the bias (if any)."""
+    assert distribution in ['uniform', 'normal']
+    fill_weight = (nn.init.xavier_uniform_
+                   if distribution == 'uniform' else nn.init.xavier_normal_)
+    if getattr(module, 'weight', None) is not None:
+        fill_weight(module.weight, gain=gain)
+    if getattr(module, 'bias', None) is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):  # weight ~ N(mean, std^2)
+    if getattr(module, 'weight', None) is not None:
+        nn.init.normal_(module.weight, mean, std)
+    if getattr(module, 'bias', None) is not None:  # absent/None is skipped
+        nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0, std: float = 1,
+                      a: float = -2, b: float = 2,
+                      bias: float = 0) -> None:
+    """Truncated-normal weight in ``[a, b]``; constant bias."""
+    # Missing or ``None`` attributes are skipped.
+    if getattr(module, 'weight', None) is not None:
+        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+    if getattr(module, 'bias', None) is not None:
+        nn.init.constant_(module.bias, bias)  # type: ignore
+
+
+def uniform_init(module, a=0, b=1, bias=0):  # weight ~ U(a, b)
+    if getattr(module, 'weight', None) is not None:
+        nn.init.uniform_(module.weight, a, b)
+    if getattr(module, 'bias', None) is not None:  # absent/None is skipped
+        nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 bias=0,
+                 distribution='normal'):
+    """Kaiming-initialise the weight (if any) and fill the bias (if any)."""
+    assert distribution in ['uniform', 'normal']
+    # ``distribution`` selects the uniform or the normal Kaiming variant.
+    fill_weight = (nn.init.kaiming_uniform_
+                   if distribution == 'uniform' else nn.init.kaiming_normal_)
+    # Attributes that are missing or ``None`` are skipped.
+    if getattr(module, 'weight', None) is not None:
+        fill_weight(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    if getattr(module, 'bias', None) is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+    """Initialise ``module`` the way Caffe2's ``XavierFill`` does.
+
+    `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    (acknowledgment to FAIR's internal code), so this delegates to
+    ``kaiming_init`` with the matching arguments.
+    """
+    xavier_fill = dict(a=1, mode='fan_in', nonlinearity='leaky_relu',
+                       distribution='uniform')
+    kaiming_init(module, bias=bias, **xavier_fill)
+
+
+def bias_init_with_prob(prior_prob):
+    """Return the bias whose sigmoid equals ``prior_prob`` (its logit)."""
+    odds_against = (1 - prior_prob) / prior_prob
+    return float(-np.log(odds_against))
+
+
+def _get_bases_name(m):  # names of the direct base classes of m's class
+    return [base.__name__ for base in m.__class__.__bases__]
+
+
+class BaseInit(object):
+    """Common argument validation and state for the initializers.
+
+    Args (keyword-only):
+        bias (int | float): value to fill the bias with. Defaults to 0.
+        bias_prob (float, optional): if given, the bias is derived from it
+            with ``bias_init_with_prob`` and ``bias`` is ignored.
+        layer (str | list[str], optional): class names to initialize.
+    """
+
+    def __init__(self, *, bias=0, bias_prob=None, layer=None):
+        self.wholemodule = False
+        if not isinstance(bias, (int, float)):
+            raise TypeError(f'bias must be a number, but got a {type(bias)}')
+        if bias_prob is not None and not isinstance(bias_prob, float):
+            raise TypeError('bias_prob type must be float, '
+                            f'but got {type(bias_prob)}')
+        if layer is None:
+            layer = []
+        elif not isinstance(layer, (str, list)):
+            raise TypeError('layer must be a str or a list of str, '
+                            f'but got a {type(layer)}')
+        self.bias = bias if bias_prob is None else bias_init_with_prob(
+            bias_prob)
+        self.layer = [layer] if isinstance(layer, str) else layer
+
+    def _get_init_info(self):
+        return f'{self.__class__.__name__}, bias={self.bias}'
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+    """Fill the weights of the selected modules with a constant.
+
+    Args:
+        val (int | float): the value to fill the weights in the module with
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, val, **kwargs):
+        super().__init__(**kwargs)
+        self.val = val
+
+    def __call__(self, module):
+
+        def init(m):
+            # Unless the whole module is targeted, skip modules whose class
+            # name and direct base-class names are all absent from ``layer``.
+            if not self.wholemodule:
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            constant_init(m, self.val, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        cls_name = self.__class__.__name__
+        return f'{cls_name}: val={self.val}, bias={self.bias}'
+
+
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+    r"""Initialize module parameters with values according to the method
+    described in "Understanding the difficulty of training deep feedforward
+    neural networks" - Glorot, X. & Bengio, Y. (2010), using
+    ``xavier_init`` on every selected module.
+
+    Args:
+        gain (int | float): an optional scaling factor. Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'``
+            or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, gain=1, distribution='normal', **kwargs):
+        super().__init__(**kwargs)
+        self.gain = gain
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            # Unless the whole module is targeted, skip modules whose class
+            # name and direct base-class names are all absent from ``layer``.
+            if not self.wholemodule:
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            xavier_init(m, self.gain, self.bias, self.distribution)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}: gain={self.gain}, '
+                f'distribution={self.distribution}, bias={self.bias}')
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+    Args:
+        mean (int | float): the mean of the normal distribution. Defaults to 0.
+        std (int | float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+
+    """
+
+    def __init__(self, mean=0, std=1, **kwargs):
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, module):
+
+        def init(m):
+            # Unless the whole module is targeted, skip modules whose class
+            # name and direct base-class names are all absent from ``layer``.
+            if not self.wholemodule:
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            normal_init(m, self.mean, self.std, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}: mean={self.mean},'
+                f' std={self.std}, bias={self.bias}')
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`, truncated
+    to the interval :math:`[a, b]`.
+
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): The minimum cutoff value.
+        b (float): The maximum cutoff value.
+        bias (float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+
+    """
+
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+
+    def __call__(self, module: nn.Module) -> None:
+
+        def init(m):
+            # Unless the whole module is targeted, skip modules whose class
+            # name and direct base-class names are all absent from
+            # ``layer``.
+            if not self.wholemodule:
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                              self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}: a={self.a}, b={self.b},'
+                f' mean={self.mean}, std={self.std}, bias={self.bias}')
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+    r"""Initialize module parameters with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+
+    Args:
+        a (int | float): the lower bound of the uniform distribution.
+            Defaults to 0.
+        b (int | float): the upper bound of the uniform distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, a=0, b=1, **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.b = b
+
+    def __call__(self, module):
+
+        def init(m):
+            # Unless the whole module is targeted, skip modules whose class
+            # name and direct base-class names are all absent from ``layer``.
+            if not self.wholemodule:
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            uniform_init(m, self.a, self.b, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}: a={self.a},'
+                f' b={self.b}, bias={self.bias}')
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+    r"""Initialize module parameters with the values according to the method
+    described in "Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification" - He, K. et al. (2015), using
+    ``kaiming_init`` on every selected module.
+
+    Args:
+        a (int | float): the negative slope of the rectifier used after this
+            layer (only used with ``'leaky_relu'``). Defaults to 0.
+        mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+            ``'fan_in'`` preserves the magnitude of the variance of the weights
+            in the forward pass. Choosing ``'fan_out'`` preserves the
+            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+        nonlinearity (str): the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
+            Defaults to 'relu'.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'`` or
+            ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer will be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 distribution='normal',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.mode = mode
+        self.nonlinearity = nonlinearity
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            # Unless the whole module is targeted, skip modules whose class
+            # name and direct base-class names are all absent from
+            # ``layer``.
+            if not self.wholemodule:
+                names = {m.__class__.__name__, *_get_bases_name(m)}
+                if names.isdisjoint(self.layer):
+                    return
+            kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                         self.bias, self.distribution)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}: a={self.a}, mode={self.mode}, '
+                f'nonlinearity={self.nonlinearity}, '
+                f'distribution ={self.distribution}, bias={self.bias}')
+
+
+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+    """Kaiming initializer preset that mirrors Caffe2's ``XavierFill``.
+
+    `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch.
+    Acknowledgment to FAIR's internal code.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(a=1, mode='fan_in', nonlinearity='leaky_relu',
+                         distribution='uniform', **kwargs)
+
+    def __call__(self, module):
+        super().__call__(module)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit(object):
+    """Initialize module by loading a pretrained model.
+
+    Args:
+        checkpoint (str): the checkpoint file of the pretrained model that
+            is to be loaded.
+        prefix (str, optional): the prefix of a sub-module in the pretrained
+            model. it is for loading a part of the pretrained model to
+            initialize. For example, if we would like to only load the
+            backbone of a detector model, we can set ``prefix='backbone.'``.
+            Defaults to None.
+        map_location (str): map tensors into proper locations.
+    """
+
+    def __init__(self, checkpoint, prefix=None, map_location=None):
+        self.checkpoint = checkpoint
+        self.prefix = prefix
+        self.map_location = map_location
+
+    def __call__(self, module):
+        from annotator.mmpkg.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+                                                 load_state_dict)
+        logger = get_logger('mmcv')
+        if self.prefix is None:
+            print_log(f'load model from: {self.checkpoint}', logger=logger)
+            load_checkpoint(
+                module,
+                self.checkpoint,
+                map_location=self.map_location,
+                strict=False,
+                logger=logger)
+        else:
+            print_log(
+                f'load {self.prefix} in model from: {self.checkpoint}',
+                logger=logger)
+            state_dict = _load_checkpoint_with_prefix(
+                self.prefix, self.checkpoint, map_location=self.map_location)
+            load_state_dict(module, state_dict, strict=False, logger=logger)
+
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+        return info
+
+
+def _initialize(module, cfg, wholemodule=False):
+    """Build the initializer described by ``cfg`` and run it on ``module``."""
+    initializer = build_from_cfg(cfg, INITIALIZERS)
+    # ``wholemodule`` is for override mode: an override carries no ``layer``
+    # key, so the initializer covers the whole module named in the override.
+    initializer.wholemodule = wholemodule
+    initializer(module)
+
+
+def _initialize_override(module, override, cfg):
+    """Initialize the named attributes of ``module`` listed in ``override``.
+
+    Each override dict must carry a ``name``. With only ``name`` the
+    arguments of ``cfg`` are reused; otherwise it needs its own ``type``.
+    """
+    if not isinstance(override, (dict, list)):
+        # Adjacent literals keep source indentation out of the message.
+        raise TypeError('override must be a dict or a list of dict, '
+                        f'but got {type(override)}')
+
+    for override_ in ([override] if isinstance(override, dict) else override):
+        cp_override = copy.deepcopy(override_)
+        name = cp_override.pop('name', None)
+        if name is None:
+            raise ValueError('`override` must contain the key "name", '
+                             f'but got {cp_override}')
+        if not cp_override:
+            # only the name was given: fall back to the args of init_cfg
+            cp_override.update(cfg)
+        elif 'type' not in cp_override.keys():
+            raise ValueError(
+                f'`override` need "type" key, but got {cp_override}')
+
+        if not hasattr(module, name):
+            raise RuntimeError(f'module did not have attribute {name}, '
+                               f'but init_cfg is {cp_override}.')
+        _initialize(getattr(module, name), cp_override, wholemodule=True)
+
+
+def initialize(module, init_cfg):
+    """Initialize a module.
+
+    Args:
+        module (``torch.nn.Module``): the module will be initialized.
+        init_cfg (dict | list[dict]): initialization configuration dict to
+            define initializer. The registered initializers are
+            ``Constant``, ``Xavier``, ``Normal``, ``TruncNormal``,
+            ``Uniform``, ``Kaiming``, ``Caffe2Xavier`` and ``Pretrained``.
+    Example:
+        >>> module = nn.Linear(2, 3, bias=True)
+        >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
+        >>> initialize(module, init_cfg)
+
+        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+        >>> # define key ``'layer'`` for initializing layer with different
+        >>> # configuration
+        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+                dict(type='Constant', layer='Linear', val=2)]
+        >>> initialize(module, init_cfg)
+
+        >>> # define key``'override'`` to initialize some specific part in
+        >>> # module
+        >>> class FooNet(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.feat = nn.Conv2d(3, 16, 3)
+        >>>         self.reg = nn.Conv2d(16, 10, 3)
+        >>>         self.cls = nn.Conv2d(16, 5, 3)
+        >>> model = FooNet()
+        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
+        >>> initialize(model, init_cfg)
+
+        >>> model = ResNet(depth=50)
+        >>> # Initialize weights with the pretrained model.
+        >>> init_cfg = dict(type='Pretrained',
+                checkpoint='torchvision://resnet50')
+        >>> initialize(model, init_cfg)
+
+        >>> # Initialize weights of a sub-module with the specific part of
+        >>> # a pretrained model by using "prefix".
+        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+        >>>     'retinanet_r50_fpn_1x_coco/'\
+        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+        >>> init_cfg = dict(type='Pretrained',
+                checkpoint=url, prefix='backbone.')
+    """
+    if not isinstance(init_cfg, (dict, list)):
+        raise TypeError('init_cfg must be a dict or a list of dict, '
+                        f'but got {type(init_cfg)}')
+
+    if isinstance(init_cfg, dict):
+        init_cfg = [init_cfg]
+
+    for cfg in init_cfg:
+        # should deeply copy the original config because cfg may be used by
+        # other modules, e.g., one init_cfg shared by multiple bottleneck
+        # blocks, the expected cfg will be changed after pop and will change
+        # the initialization behavior of other modules
+        cp_cfg = copy.deepcopy(cfg)
+        override = cp_cfg.pop('override', None)
+        _initialize(module, cp_cfg)
+
+        if override is not None:
+            # ``layer`` is dropped before the cfg is reused for overrides.
+            cp_cfg.pop('layer', None)
+            _initialize_override(module, override, cp_cfg)
+        # Without an override all attributes in the module share the same
+        # initialization, so nothing more is needed.
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    # Fills ``tensor`` in place and returns it. Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower, upper], then translate
+        # to [2lower-1, 2upper-1].
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fills the input Tensor in place with values drawn from a truncated
+    normal distribution and returns it. The values follow the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    restricted to :math:`[a, b]`; they are produced by inverse-CDF
+    sampling followed by a clamp. The method works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py
@@ -0,0 +1,175 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+from .utils import constant_init, kaiming_init, normal_init
+
+
+def conv3x3(in_planes, out_planes, dilation=1):
+    """Build a 3x3 ``nn.Conv2d`` whose padding equals its dilation.
+
+    Padding and dilation are tied so that, at the default stride of 1, the
+    output keeps the spatial size of the input.
+    """
+    conv_kwargs = dict(kernel_size=3, padding=dilation, dilation=dilation)
+    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
+
+
+def make_vgg_layer(inplanes,
+                   planes,
+                   num_blocks,
+                   dilation=1,
+                   with_bn=False,
+                   ceil_mode=False):
+    """Return the module list of one VGG stage: conv blocks, then a pool."""
+    stage = []
+    in_channels = inplanes
+    for _ in range(num_blocks):
+        block = [conv3x3(in_channels, planes, dilation)]
+        if with_bn:
+            block.append(nn.BatchNorm2d(planes))
+        stage.extend(block + [nn.ReLU(inplace=True)])
+        in_channels = planes
+    return stage + [nn.MaxPool2d(2, stride=2, ceil_mode=ceil_mode)]
+
+
+class VGG(nn.Module):
+    """VGG backbone.
+
+    Args:
+        depth (int): Depth of vgg, from {11, 13, 16, 19}.
+        with_bn (bool): Use BatchNorm or not.
+        num_classes (int): Number of classes; classifier built only if > 0.
+        num_stages (int): VGG stages, normally 5.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Number of leading stages to freeze in train mode.
+        bn_eval (bool): Keep BN layers in eval mode (freeze running stats).
+        bn_frozen (bool): Freeze weight and bias of BN layers (needs bn_eval).
+        ceil_mode (bool): ``ceil_mode`` of the max-pool ending each stage.
+        with_last_pool (bool): Keep the max-pool of the last stage or not.
+    """
+
+    arch_settings = {
+        11: (1, 1, 2, 2, 2),
+        13: (2, 2, 2, 2, 2),
+        16: (2, 2, 3, 3, 3),
+        19: (2, 2, 4, 4, 4)
+    }
+
+    def __init__(self,
+                 depth,
+                 with_bn=False,
+                 num_classes=-1,
+                 num_stages=5,
+                 dilations=(1, 1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3, 4),
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False,
+                 ceil_mode=False,
+                 with_last_pool=True):
+        super(VGG, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for vgg')
+        assert num_stages >= 1 and num_stages <= 5
+        stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = stage_blocks[:num_stages]
+        assert len(dilations) == num_stages
+        assert max(out_indices) <= num_stages
+
+        self.num_classes = num_classes
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+
+        self.inplanes = 3
+        start_idx = 0
+        vgg_layers = []
+        self.range_sub_modules = []  # per stage: [start, end) module indices
+        for i, num_blocks in enumerate(self.stage_blocks):
+            num_modules = num_blocks * (2 + with_bn) + 1  # +1: max-pool
+            end_idx = start_idx + num_modules
+            dilation = dilations[i]
+            planes = 64 * 2**i if i < 4 else 512  # 64, 128, 256, 512, 512
+            vgg_layer = make_vgg_layer(
+                self.inplanes,
+                planes,
+                num_blocks,
+                dilation=dilation,
+                with_bn=with_bn,
+                ceil_mode=ceil_mode)
+            vgg_layers.extend(vgg_layer)
+            self.inplanes = planes
+            self.range_sub_modules.append([start_idx, end_idx])
+            start_idx = end_idx
+        if not with_last_pool:
+            vgg_layers.pop(-1)
+            self.range_sub_modules[-1][1] -= 1  # last stage lost its pool
+        self.module_name = 'features'
+        self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Linear(512 * 7 * 7, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            from ..runner import load_checkpoint
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, nn.BatchNorm2d):
+                    constant_init(m, 1)
+                elif isinstance(m, nn.Linear):
+                    normal_init(m, std=0.01)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        outs = []
+        vgg_layers = getattr(self, self.module_name)
+        for i in range(len(self.stage_blocks)):
+            for j in range(*self.range_sub_modules[i]):
+                vgg_layer = vgg_layers[j]
+                x = vgg_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), -1)  # flatten everything but dim 0
+            x = self.classifier(x)
+            outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(VGG, self).train(mode)
+        if self.bn_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    if self.bn_frozen:
+                        for params in m.parameters():
+                            params.requires_grad = False
+        vgg_layers = getattr(self, self.module_name)
+        if mode and self.frozen_stages >= 0:
+            for i in range(self.frozen_stages):
+                for j in range(*self.range_sub_modules[i]):
+                    mod = vgg_layers[j]
+                    mod.eval()
+                    for param in mod.parameters():
+                        param.requires_grad = False
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/engine/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+ single_gpu_test)
+
+__all__ = [
+ 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+ 'single_gpu_test'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/engine/test.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/engine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad5f55c4b181f7ad7bf17ed9003496f7377bbd3e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/engine/test.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import torch
+import torch.distributed as dist
+
+import annotator.mmpkg.mmcv as mmcv
+from annotator.mmpkg.mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model, data_loader):
+    """Test model with a single gpu.
+
+    Runs the model in eval mode over ``data_loader`` without gradients and
+    shows a progress bar sized by ``len(data_loader.dataset)``.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (torch.utils.data.DataLoader): Pytorch data loader.
+
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for data in data_loader:
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        results.extend(result)
+
+        # Assume ``result`` holds one entry per sample of the batch,
+        # refer to https://github.com/open-mmlab/mmcv/issues/985
+        batch_size = len(result)
+        for _ in range(batch_size):
+            prog_bar.update()
+    return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+    """Test model with multiple gpus.
+
+    This method tests model with multiple gpus and collects the results
+    under two different modes: gpu and cpu modes. By setting
+    ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
+    communication for results collection. On cpu mode it saves the results on
+    different gpus to ``tmpdir`` and collects them by the rank 0 worker.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (torch.utils.data.DataLoader): Pytorch data loader.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode.
+        gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+    Returns:
+        list | None: Prediction results on rank 0, ``None`` on other ranks.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    time.sleep(2)  # This line can prevent deadlock problem in some cases.
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        results.extend(result)
+
+        if rank == 0:
+            # Only rank 0 draws the bar; it advances by this rank's batch
+            # size times ``world_size``, capped at the dataset length.
+            batch_size = len(result)
+            batch_size_all = batch_size * world_size
+            if batch_size_all + prog_bar.completed > len(dataset):
+                batch_size_all = len(dataset) - prog_bar.completed
+            for _ in range(batch_size_all):
+                prog_bar.update()
+    # collect results from all ranks
+    if gpu_collect:
+        results = collect_results_gpu(results, len(dataset))
+    else:
+        results = collect_results_cpu(results, len(dataset), tmpdir)
+    return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+    """Collect results under cpu mode.
+
+    On cpu mode, this function will save the results on different gpus to
+    ``tmpdir`` and collect them by the rank 0 worker.
+
+    Args:
+        result_part (list): Result list containing result parts
+            to be collected.
+        size (int): Size of the results, commonly equal to length of
+            the results.
+        tmpdir (str | None): Temporary directory to store the per-rank
+            results in. If None, rank 0 creates a random directory under
+            ``.dist_test`` and broadcasts its path to the other ranks.
+
+    Returns:
+        list | None: Collected results on rank 0, ``None`` on other ranks.
+    """
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is the ASCII space; the padding is removed by rstrip() below
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mmcv.mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, f'part_{i}.pkl')
+            part_result = mmcv.load(part_file)
+            # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+            if part_result:
+                part_list.append(part_result)
+        # interleave the parts: zip() takes one item per rank at a time
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+    """Collect results under gpu mode.
+
+    On gpu mode, this function will pickle the results into uint8 gpu
+    tensors and use ``dist.all_gather`` for results collection.
+
+    Args:
+        result_part (list): Result list containing result parts
+            to be collected.
+        size (int): Size of the results, commonly equal to length of
+            the results.
+
+    Returns:
+        list | None: Collected results on rank 0, ``None`` on other ranks.
+    """
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather the byte length of every rank's pickled part
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # zero-pad this rank's part to the max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+            # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+            if part_result:
+                part_list.append(part_result)
+        # interleave the parts: zip() takes one item per rank at a time
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        return ordered_results
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
+
+__all__ = [
+ 'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+ 'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+ 'list_from_file', 'dict_from_file'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ed2bf5f41a29000f9a080066497d8f3674fae15
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py
@@ -0,0 +1,1148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterable, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+
+import annotator.mmpkg.mmcv as mmcv
+from annotator.mmpkg.mmcv.utils.misc import has_method
+from annotator.mmpkg.mmcv.utils.path import is_filepath
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+    """Abstract class of storage backends.
+
+    All backends need to implement two apis: ``get()`` and ``get_text()``.
+    ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+    as text.
+    """
+
+    # a flag to indicate whether the backend can create a symlink for a file
+    _allow_symlink = False
+
+    @property
+    def name(self):  # class name of the concrete backend
+        return self.__class__.__name__
+
+    @property
+    def allow_symlink(self):  # read-only view of ``_allow_symlink``
+        return self._allow_symlink
+
+    @abstractmethod
+    def get(self, filepath):  # subclasses must override
+        pass
+
+    @abstractmethod
+    def get_text(self, filepath):  # subclasses must override
+        pass
+
+
+class CephBackend(BaseStorageBackend):
+    """Ceph storage backend (for internal use).
+
+    Args:
+        path_mapping (dict|None): path mapping dict from local path to Petrel
+            path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+            will be replaced by ``dst``. Default: None.
+
+    .. warning::
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+    """
+
+    def __init__(self, path_mapping=None):
+        try:
+            import ceph
+        except ImportError:
+            raise ImportError('Please install ceph to enable CephBackend.')
+
+        warnings.warn(
+            'CephBackend will be deprecated, please use PetrelBackend instead')
+        self._client = ceph.S3Client()
+        assert isinstance(path_mapping, dict) or path_mapping is None
+        self.path_mapping = path_mapping
+
+    def get(self, filepath):  # returns a memoryview over the fetched value
+        filepath = str(filepath)
+        if self.path_mapping is not None:
+            for k, v in self.path_mapping.items():
+                filepath = filepath.replace(k, v)  # every occurrence of k
+        value = self._client.Get(filepath)
+        value_buf = memoryview(value)
+        return value_buf
+
+    def get_text(self, filepath, encoding=None):  # text reads unsupported
+        raise NotImplementedError
+
+
+class PetrelBackend(BaseStorageBackend):
+    """Petrel storage backend (for internal use).
+
+    PetrelBackend supports reading and writing data to multiple clusters.
+    If the file path contains the cluster name, PetrelBackend will read data
+    from specified cluster or write data to it. Otherwise, PetrelBackend will
+    access the default cluster.
+
+    Args:
+        path_mapping (dict, optional): Path mapping dict from local path to
+            Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+            ``filepath`` will be replaced by ``dst``. Default: None.
+        enable_mc (bool, optional): Whether to enable memcached support.
+            Default: True.
+
+    Examples:
+        >>> filepath1 = 's3://path/of/file'
+        >>> filepath2 = 'cluster-name:s3://path/of/file'
+        >>> client = PetrelBackend()
+        >>> client.get(filepath1)  # get data from default cluster
+        >>> client.get(filepath2)  # get data from 'cluster-name' cluster
+    """
+
+    def __init__(self,
+                 path_mapping: Optional[dict] = None,
+                 enable_mc: bool = True):
+        try:
+            from petrel_client import client
+        except ImportError:
+            raise ImportError('Please install petrel_client to enable '
+                              'PetrelBackend.')
+
+        self._client = client.Client(enable_mc=enable_mc)
+        assert isinstance(path_mapping, dict) or path_mapping is None
+        self.path_mapping = path_mapping
+
+    def _map_path(self, filepath: Union[str, Path]) -> str:
+        """Map ``filepath`` to a string path in which every occurrence of a
+        :attr:`self.path_mapping` key is replaced by its value.
+
+        Args:
+            filepath (str): Path to be mapped.
+        """
+        filepath = str(filepath)
+        if self.path_mapping is not None:
+            for k, v in self.path_mapping.items():
+                filepath = filepath.replace(k, v)
+        return filepath
+
+    def _format_path(self, filepath: str) -> str:
+        """Convert a ``filepath`` to standard format of petrel oss.
+
+        If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
+        environment, the ``filepath`` will be the format of
+        's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+        above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+        Args:
+            filepath (str): Path to be formatted.
+        """
+        return re.sub(r'\\+', '/', filepath)
+
+    def get(self, filepath: Union[str, Path]) -> memoryview:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            memoryview: A memory view of expected bytes object to avoid
+                copying. The memoryview object can be converted to bytes by
+                ``value_buf.tobytes()``.
+        """
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        value = self._client.Get(filepath)
+        value_buf = memoryview(value)
+        return value_buf
+
+    def get_text(self,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        return str(self.get(filepath), encoding=encoding)
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Save data to a given ``filepath``.
+
+        Args:
+            obj (bytes): Data to be saved.
+            filepath (str or Path): Path to write data.
+        """
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        self._client.put(filepath, obj)
+
+    def put_text(self,
+                 obj: str,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> None:
+        """Save data to a given ``filepath``.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to encode the ``obj``.
+                Default: 'utf-8'.
+        """
+        self.put(bytes(obj, encoding=encoding), filepath)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str or Path): Path to be removed.
+        """
+        if not has_method(self._client, 'delete'):
+            raise NotImplementedError(
+                ('Current version of Petrel Python SDK has not supported '
+                 'the `delete` method, please use a higher version or dev'
+                 ' branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        self._client.delete(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        if not (has_method(self._client, 'contains')
+                and has_method(self._client, 'isdir')):
+            raise NotImplementedError(
+                ('Current version of Petrel Python SDK has not supported '
+                 'the `contains` and `isdir` methods, please use a higher'
+                 ' version or dev branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        return self._client.contains(filepath) or self._client.isdir(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+                ``False`` otherwise.
+        """
+        if not has_method(self._client, 'isdir'):
+            raise NotImplementedError(
+                ('Current version of Petrel Python SDK has not supported '
+                 'the `isdir` method, please use a higher version or dev'
+                 ' branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        return self._client.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                otherwise.
+        """
+        if not has_method(self._client, 'contains'):
+            raise NotImplementedError(
+                ('Current version of Petrel Python SDK has not supported '
+                 'the `contains` method, please use a higher version or '
+                 'dev branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        return self._client.contains(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result after concatenation.
+        """
+        filepath = self._format_path(self._map_path(filepath))
+        if filepath.endswith('/'):
+            filepath = filepath[:-1]
+        formatted_paths = [filepath]
+        for path in filepaths:
+            formatted_paths.append(self._format_path(self._map_path(path)))
+        return '/'.join(formatted_paths)
+
+    @contextmanager
+    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+        """Download a file from ``filepath`` and return a temporary path.
+
+        ``get_local_path`` is decorated by
+        :meth:`contextlib.contextmanager`. It can be called with a ``with``
+        statement; on exit from it the temporary file is removed.
+
+        Args:
+            filepath (str | Path): Download a file from ``filepath``.
+
+        Examples:
+            >>> client = PetrelBackend()
+            >>> # The temporary file is removed on leaving the ``with`` block
+            >>> with client.get_local_path('s3://path/of/your/file') as path:
+            ...     # do something here
+
+        Yields:
+            Iterable[str]: Only yield one temporary path.
+        """
+        # ``isfile`` and ``get`` map/format the path themselves; mapping it
+        # here as well would apply ``path_mapping`` twice.
+        assert self.isfile(filepath)
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:
+            f.write(self.get(filepath))
+            f.close()
+            yield f.name
+        finally:
+            f.close()  # no-op if already closed
+            os.remove(f.name)
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the interested directories or files in
+        arbitrary order.
+
+        Note:
+            Petrel has no concept of directories but it simulates the directory
+            hierarchy in the filesystem through public prefixes. In addition,
+            if the returned path ends with '/', it means the path is a public
+            prefix which is a logical directory.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+            In addition, the returned path of directory will not contain the
+            suffix '/' which is consistent with other backends.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional): File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        if not has_method(self._client, 'list'):
+            raise NotImplementedError(
+                ('Current version of Petrel Python SDK has not supported '
+                 'the `list` method, please use a higher version or dev'
+                 ' branch instead.'))
+
+        dir_path = self._map_path(dir_path)
+        dir_path = self._format_path(dir_path)
+        if list_dir and suffix is not None:
+            raise TypeError(
+                '`list_dir` should be False when `suffix` is not None')
+
+        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+            raise TypeError('`suffix` must be a string or tuple of strings')
+
+        # Petrel's simulated directory hierarchy assumes that directory paths
+        # should end with `/`
+        if not dir_path.endswith('/'):
+            dir_path += '/'
+
+        root = dir_path
+
+        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                              recursive):
+            for path in self._client.list(dir_path):
+                # the `self.isdir` is not used here to determine whether path
+                # is a directory, because `self.isdir` relies on
+                # `self._client.list`
+                if path.endswith('/'):  # a directory path
+                    next_dir_path = self.join_path(dir_path, path)
+                    if list_dir:
+                        # get the relative path and exclude the last
+                        # character '/'
+                        rel_dir = next_dir_path[len(root):-1]
+                        yield rel_dir
+                    if recursive:
+                        yield from _list_dir_or_file(next_dir_path, list_dir,
+                                                     list_file, suffix,
+                                                     recursive)
+                else:  # a file path
+                    absolute_path = self.join_path(dir_path, path)
+                    rel_path = absolute_path[len(root):]
+                    if (suffix is None
+                            or rel_path.endswith(suffix)) and list_file:
+                        yield rel_path
+
+        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                 recursive)
+
+
+class MemcachedBackend(BaseStorageBackend):
+    """Memcached storage backend.
+
+    Args:
+        server_list_cfg (str): Config file for memcached server list.
+        client_cfg (str): Config file for memcached client.
+        sys_path (str | None): Additional path to be appended to `sys.path`.
+            Default: None.
+    """
+
+    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+        if sys_path is not None:
+            import sys
+            sys.path.append(sys_path)
+        try:
+            import mc
+        except ImportError:
+            raise ImportError(
+                'Please install memcached to enable MemcachedBackend.')
+
+        self.server_list_cfg = server_list_cfg
+        self.client_cfg = client_cfg
+        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+                                                      self.client_cfg)
+        # mc.pyvector serves as a pointer which points to a memory cache
+        self._mc_buffer = mc.pyvector()
+
+    def get(self, filepath):  # fills the shared buffer, then converts it
+        filepath = str(filepath)
+        import mc
+        self._client.Get(filepath, self._mc_buffer)
+        value_buf = mc.ConvertBuffer(self._mc_buffer)
+        return value_buf
+
+    def get_text(self, filepath, encoding=None):  # text reads unsupported
+        raise NotImplementedError
+
+
+class LmdbBackend(BaseStorageBackend):
+    """Lmdb storage backend.
+
+    Args:
+        db_path (str): Lmdb database path.
+        readonly (bool, optional): Lmdb environment parameter. If True,
+            disallow any write operations. Default: True.
+        lock (bool, optional): Lmdb environment parameter. If False, when
+            concurrent access occurs, do not lock the database. Default: False.
+        readahead (bool, optional): Lmdb environment parameter. If False,
+            disable the OS filesystem readahead mechanism, which may improve
+            random reads when the database is larger than RAM. Default: False.
+        **kwargs: Extra keyword arguments forwarded to ``lmdb.open``.
+
+    Attributes:
+        db_path (str): Lmdb database path.
+    """
+
+    def __init__(self,
+                 db_path,
+                 readonly=True,
+                 lock=False,
+                 readahead=False,
+                 **kwargs):
+        try:
+            import lmdb
+        except ImportError:
+            raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+        self.db_path = str(db_path)
+        self._client = lmdb.open(
+            self.db_path,
+            readonly=readonly,
+            lock=lock,
+            readahead=readahead,
+            **kwargs)
+
+    def get(self, filepath):
+        """Get the value stored under ``filepath`` in a read-only transaction.
+
+        Args:
+            filepath (str | :obj:`Path`): The lmdb key; encoded as ASCII.
+        """
+        filepath = str(filepath)
+        with self._client.begin(write=False) as txn:
+            value_buf = txn.get(filepath.encode('ascii'))
+        return value_buf
+
+    def get_text(self, filepath, encoding=None):  # text reads unsupported
+        raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+    """Raw hard disks storage backend."""
+
+    _allow_symlink = True
+
+    def get(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, 'rb') as f:
+            value_buf = f.read()
+        return value_buf
+
+    def get_text(self,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        with open(filepath, 'r', encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``put`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        mmcv.mkdir_or_exist(osp.dirname(filepath))
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def put_text(self,
+                 obj: str,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``put_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        mmcv.mkdir_or_exist(osp.dirname(filepath))
+        with open(filepath, 'w', encoding=encoding) as f:
+            f.write(obj)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str or Path): Path to be removed.
+        """
+        os.remove(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        return osp.exists(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+                ``False`` otherwise.
+        """
+        return osp.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                otherwise.
+        """
+        return osp.isfile(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Join one or more filepath components with ``os.path.join``. The
+        return value is the join of filepath and any members of *filepaths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result of concatenation.
+        """
+        return osp.join(filepath, *filepaths)
+
+    @contextmanager
+    def get_local_path(
+            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+        """Yield ``filepath`` unchanged; exists only for a unified API."""
+        yield filepath
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the interested directories or files in
+        arbitrary order.
+
+        Note:
+            Yields paths relative to ``dir_path``; dot-files are skipped.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional): File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        if list_dir and suffix is not None:
+            raise TypeError('`suffix` should be None when `list_dir` is True')
+
+        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+            raise TypeError('`suffix` must be a string or tuple of strings')
+
+        root = dir_path
+
+        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                              recursive):
+            for entry in os.scandir(dir_path):
+                if not entry.name.startswith('.') and entry.is_file():
+                    rel_path = osp.relpath(entry.path, root)
+                    if (suffix is None
+                            or rel_path.endswith(suffix)) and list_file:
+                        yield rel_path
+                elif osp.isdir(entry.path):
+                    if list_dir:
+                        rel_dir = osp.relpath(entry.path, root)
+                        yield rel_dir
+                    if recursive:
+                        yield from _list_dir_or_file(entry.path, list_dir,
+                                                     list_file, suffix,
+                                                     recursive)
+
+        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                 recursive)
+
+
+class HTTPBackend(BaseStorageBackend):
+    """HTTP and HTTPS storage backend."""
+
+    def get(self, filepath):
+        """Return the raw bytes fetched from the URL ``filepath``."""
+        with urlopen(filepath) as response:  # close the connection
+            return response.read()
+
+    def get_text(self, filepath, encoding='utf-8'):
+        """Return the content of ``filepath`` decoded with ``encoding``."""
+        with urlopen(filepath) as response:
+            return response.read().decode(encoding)
+
+    @contextmanager
+    def get_local_path(self, filepath: str) -> Iterable[str]:
+        """Download ``filepath`` to a temporary local file and yield its path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        Use it in a ``with`` statement; the temporary file is removed when
+        the ``with`` block exits.
+
+        Args:
+            filepath (str): URL of the file to download.
+
+        Examples:
+            >>> client = HTTPBackend()
+            >>> with client.get_local_path('http://path/of/your/file') as path:
+            ...     # do something here
+        """
+        f = tempfile.NamedTemporaryFile(delete=False)
+        try:  # ``f`` is bound above, so ``finally`` can always use it
+            with f:  # closes the handle even if the download fails
+                f.write(self.get(filepath))
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+
+class FileClient:
+    """A general file client to access files in different backends.
+
+    The client loads a file or text in a specified backend from its path
+    and returns it as a binary or text file. A backend can be chosen either
+    by its name (``backend``) or by the prefix of the path (``prefix``). If
+    both are set, ``backend`` takes priority; if both are ``None``, the disk
+    backend is used. Other backends can be registered under a given name and
+    prefixes with :meth:`register_backend`. Instances are cached (singleton
+    pattern): constructing a client with the same arguments returns the same
+    object, unless the backend or prefix was overridden.
+
+    Args:
+        backend (str, optional): The storage backend type. Options are "disk",
+            "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+        prefix (str, optional): The prefix of the registered storage backend.
+            Options are "s3", "http", "https". Default: None.
+
+    Examples:
+        >>> # only set backend
+        >>> file_client = FileClient(backend='petrel')
+        >>> # only set prefix
+        >>> file_client = FileClient(prefix='s3')
+        >>> # set both backend and prefix but use backend to choose client
+        >>> file_client = FileClient(backend='petrel', prefix='s3')
+        >>> # if the arguments are the same, the same object is returned
+        >>> file_client1 = FileClient(backend='petrel')
+        >>> file_client1 is file_client
+        True
+
+    Attributes:
+        client (:obj:`BaseStorageBackend`): The backend object.
+    """
+
+    _backends = {
+        'disk': HardDiskBackend,
+        'ceph': CephBackend,
+        'memcached': MemcachedBackend,
+        'lmdb': LmdbBackend,
+        'petrel': PetrelBackend,
+        'http': HTTPBackend,
+    }
+    # This collection is used to record the overridden backends, and when a
+    # backend appears in the collection, the singleton pattern is disabled for
+    # that backend, because if the singleton pattern is used, then the object
+    # returned will be the backend before overwriting
+    _overridden_backends = set()
+    _prefix_to_backends = {
+        's3': PetrelBackend,
+        'http': HTTPBackend,
+        'https': HTTPBackend,
+    }
+    _overridden_prefixes = set()
+
+    _instances = {}
+
+    def __new__(cls, backend=None, prefix=None, **kwargs):
+        """Return the cached client for these arguments or create one."""
+        if backend is None and prefix is None:
+            backend = 'disk'
+        if backend is not None and backend not in cls._backends:
+            raise ValueError(
+                f'Backend {backend} is not supported. Currently supported ones'
+                f' are {list(cls._backends.keys())}')
+        if prefix is not None and prefix not in cls._prefix_to_backends:
+            raise ValueError(
+                f'prefix {prefix} is not supported. Currently supported ones '
+                f'are {list(cls._prefix_to_backends.keys())}')
+
+        # concatenate the arguments to a unique key for determining whether
+        # objects with the same arguments were created
+        arg_key = f'{backend}:{prefix}'
+        for key, value in kwargs.items():
+            arg_key += f':{key}:{value}'
+
+        # if a backend was overridden, it will create a new object
+        if (arg_key in cls._instances
+                and backend not in cls._overridden_backends
+                and prefix not in cls._overridden_prefixes):
+            _instance = cls._instances[arg_key]
+        else:
+            # create a new object and put it to _instance
+            _instance = super().__new__(cls)
+            if backend is not None:
+                _instance.client = cls._backends[backend](**kwargs)
+            else:
+                _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+            cls._instances[arg_key] = _instance
+
+        return _instance
+
+    @property
+    def name(self):
+        """The ``name`` attribute of the wrapped backend object."""
+        return self.client.name
+
+    @property
+    def allow_symlink(self):
+        """The ``allow_symlink`` attribute of the wrapped backend object."""
+        return self.client.allow_symlink
+
+    @staticmethod
+    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+        """Parse the prefix of a uri.
+
+        Args:
+            uri (str | Path): Uri to be parsed that contains the file prefix.
+
+        Examples:
+            >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+            's3'
+
+        Returns:
+            str | None: Return the prefix of uri if the uri contains '://'
+            else ``None``.
+        """
+        assert is_filepath(uri)
+        uri = str(uri)
+        if '://' not in uri:
+            return None
+        else:
+            prefix, _ = uri.split('://', 1)  # '://' may recur in the URI
+            # In the case of PetrelBackend, the prefix may contain the cluster
+            # name like clusterName:s3, so keep only the part after the last ':'
+            if ':' in prefix:
+                _, prefix = prefix.rsplit(':', 1)
+            return prefix
+
+    @classmethod
+    def infer_client(cls,
+                     file_client_args: Optional[dict] = None,
+                     uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+        """Infer a suitable file client based on the URI and arguments.
+
+        Args:
+            file_client_args (dict, optional): Arguments to instantiate a
+                FileClient. Default: None.
+            uri (str | Path, optional): Uri to be parsed that contains the file
+                prefix. Default: None.
+
+        Examples:
+            >>> uri = 's3://path/of/your/file'
+            >>> file_client = FileClient.infer_client(uri=uri)
+            >>> file_client_args = {'backend': 'petrel'}
+            >>> file_client = FileClient.infer_client(file_client_args)
+
+        Returns:
+            FileClient: Instantiated FileClient object.
+        """
+        assert file_client_args is not None or uri is not None
+        if file_client_args is None:
+            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
+            return cls(prefix=file_prefix)
+        else:
+            return cls(**file_client_args)
+
+    @classmethod
+    def _register_backend(cls, name, backend, force=False, prefixes=None):
+        """Validate ``backend`` and record it under ``name``/``prefixes``."""
+        if not isinstance(name, str):
+            raise TypeError('the backend name should be a string, '
+                            f'but got {type(name)}')
+        if not inspect.isclass(backend):
+            raise TypeError(
+                f'backend should be a class but got {type(backend)}')
+        if not issubclass(backend, BaseStorageBackend):
+            raise TypeError(
+                f'backend {backend} is not a subclass of BaseStorageBackend')
+        if not force and name in cls._backends:
+            raise KeyError(
+                f'{name} is already registered as a storage backend, '
+                'add "force=True" if you want to override it')
+
+        if name in cls._backends and force:
+            cls._overridden_backends.add(name)
+        cls._backends[name] = backend
+
+        if prefixes is not None:
+            if isinstance(prefixes, str):
+                prefixes = [prefixes]
+            else:
+                assert isinstance(prefixes, (list, tuple))
+            for prefix in prefixes:
+                if prefix not in cls._prefix_to_backends:
+                    cls._prefix_to_backends[prefix] = backend
+                elif (prefix in cls._prefix_to_backends) and force:
+                    cls._overridden_prefixes.add(prefix)
+                    cls._prefix_to_backends[prefix] = backend
+                else:
+                    raise KeyError(
+                        f'{prefix} is already registered as a storage backend,'
+                        ' add "force=True" if you want to override it')
+
+    @classmethod
+    def register_backend(cls, name, backend=None, force=False, prefixes=None):
+        """Register a backend to FileClient.
+
+        This method can be used as a normal class method or a decorator.
+
+        .. code-block:: python
+
+            class NewBackend(BaseStorageBackend):
+
+                def get(self, filepath):
+                    return filepath
+
+                def get_text(self, filepath):
+                    return filepath
+
+            FileClient.register_backend('new', NewBackend)
+
+        or
+
+        .. code-block:: python
+
+            @FileClient.register_backend('new')
+            class NewBackend(BaseStorageBackend):
+
+                def get(self, filepath):
+                    return filepath
+
+                def get_text(self, filepath):
+                    return filepath
+
+        Args:
+            name (str): The name of the registered backend.
+            backend (class, optional): The backend class to be registered,
+                which must be a subclass of :class:`BaseStorageBackend`.
+                When this method is used as a decorator, backend is None.
+                Defaults to None.
+            force (bool, optional): Whether to override the backend if the name
+                has already been registered. Defaults to False.
+            prefixes (str or list[str] or tuple[str], optional): The prefixes
+                of the registered storage backend. Default: None.
+                `New in version 1.3.15.`
+        """
+        if backend is not None:
+            cls._register_backend(
+                name, backend, force=force, prefixes=prefixes)
+            return
+
+        def _register(backend_cls):
+            cls._register_backend(
+                name, backend_cls, force=force, prefixes=prefixes)
+            return backend_cls
+
+        return _register
+
+    def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Note:
+            There are two types of return values for ``get``, one is ``bytes``
+            and the other is ``memoryview``. The advantage of using memoryview
+            is that you can avoid copying, and if you want to convert it to
+            ``bytes``, you can use ``.tobytes()``.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes | memoryview: Expected bytes object or a memory view of the
+            bytes object.
+        """
+        return self.client.get(filepath)
+
+    def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        return self.client.get_text(filepath, encoding)
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``put`` should create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        self.client.put(obj, filepath)
+
+    def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``put_text`` should create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+
+        No ``encoding`` is forwarded; the backend's default one is used.
+        """
+        self.client.put_text(obj, filepath)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str, Path): Path to be removed.
+        """
+        self.client.remove(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        return self.client.exists(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+            ``False`` otherwise.
+        """
+        return self.client.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+            otherwise.
+        """
+        return self.client.isfile(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+            *filepaths (str or Path): Other paths appended to ``filepath``.
+
+        Returns:
+            str: The result of concatenation.
+        """
+        return self.client.join_path(filepath, *filepaths)
+
+    @contextmanager
+    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+        """Download data from ``filepath`` and write the data to local path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called with a ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+        Note:
+            If the ``filepath`` is a local path, just return itself.
+
+        .. warning::
+            ``get_local_path`` is an experimental interface that may change in
+            the future.
+
+        Args:
+            filepath (str or Path): Path to be read data.
+
+        Examples:
+            >>> file_client = FileClient(prefix='s3')
+            >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+            ...     # do something here
+
+        Yields:
+            Iterable[str]: Only yield one path.
+        """
+        with self.client.get_local_path(str(filepath)) as local_path:
+            yield local_path
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the interested directories or files in
+        arbitrary order.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional): File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+                                                suffix, recursive)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+    # ``str_like`` tells callers which kind of file object the handler works
+    # with: ``True`` for str-like objects (json; the buffer is then processed
+    # with ``StringIO``), ``False`` for bytes-like objects (pickle).
+    # Str-like is the default; subclasses override the flag when needed.
+    str_like = True
+
+    @abstractmethod
+    def load_from_fileobj(self, file, **kwargs):
+        """Deserialize and return the object read from the open ``file``."""
+
+    @abstractmethod
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        """Serialize ``obj`` and write it to the open ``file``."""
+
+    @abstractmethod
+    def dump_to_str(self, obj, **kwargs):
+        """Serialize ``obj`` and return the result instead of writing it."""
+
+    def load_from_path(self, filepath, mode='r', **kwargs):
+        with open(filepath, mode) as stream:
+            return self.load_from_fileobj(stream, **kwargs)
+
+    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+        with open(filepath, mode) as stream:
+            self.dump_to_fileobj(obj, stream, **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+import numpy as np
+
+from .base import BaseFileHandler
+
+
+def set_default(obj):
+    """Fallback serializer for values the json encoder cannot handle.
+
+    ``set`` and ``range`` objects become lists, ``np.ndarray`` is converted
+    with ``tolist()`` and ``np.generic`` scalars (``np.int32``, ``np.float32``,
+    etc.) with ``item()``. Any other type raises ``TypeError``.
+    """
+    if isinstance(obj, (set, range)):
+        return list(obj)
+    if isinstance(obj, np.ndarray):
+        return obj.tolist()
+    if isinstance(obj, np.generic):
+        return obj.item()
+    raise TypeError(f'{type(obj)} is unsupported for json dump')
+
+
+class JsonHandler(BaseFileHandler):
+    """Handler for json text; ``set_default`` serializes numpy/set values."""
+    def load_from_fileobj(self, file, **kwargs):  # kwargs go to json.load
+        return json.load(file, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        kwargs.setdefault('default', set_default)
+        json.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault('default', set_default)
+        return json.dumps(obj, **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+    """Handler for pickle data; it works on bytes-like file objects."""
+    str_like = False
+
+    def load_from_fileobj(self, file, **kwargs):
+        # NOTE(review): unpickling can execute code -- only load trusted data.
+        return pickle.load(file, **kwargs)
+
+    def load_from_path(self, filepath, **kwargs):
+        return super().load_from_path(filepath, mode='rb', **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault('protocol', 2)
+        return pickle.dumps(obj, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        # Protocol 2 is only a default; a caller-supplied value wins.
+        kwargs.setdefault('protocol', 2)
+        pickle.dump(obj, file, **kwargs)
+
+    def dump_to_path(self, obj, filepath, **kwargs):
+        super().dump_to_path(obj, filepath, mode='wb', **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+from .base import BaseFileHandler # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+    """Handler for yaml text, using the module-level Loader/Dumper."""
+    def load_from_fileobj(self, file, **kwargs):
+        # NOTE(review): the default ``Loader`` is not ``SafeLoader`` -- confirm
+        # that callers never pass untrusted input.
+        return yaml.load(file, **{'Loader': Loader, **kwargs})
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        # A caller-supplied ``Dumper`` takes precedence over the default.
+        yaml.dump(obj, file, **{'Dumper': Dumper, **kwargs})
+
+    def dump_to_str(self, obj, **kwargs):
+        return yaml.dump(obj, **{'Dumper': Dumper, **kwargs})
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/io.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/io.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+
+from ..utils import is_list_of, is_str
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+file_handlers = dict(  # file format / extension -> handler instance
+    json=JsonHandler(),
+    yaml=YamlHandler(),
+    yml=YamlHandler(),
+    pickle=PickleHandler(),
+    pkl=PickleHandler(),
+)
+
+
+def load(file, file_format=None, file_client_args=None, **kwargs):
+    """Load data from json/yaml/pickle files.
+
+    Paths are read through a :class:`FileClient`, so (since v1.3.16) the file
+    may be stored in any supported backend.
+
+    Args:
+        file (str or :obj:`Path` or file-like object): Filename or a file-like
+            object.
+        file_format (str, optional): Built-in formats are "json", "yaml/yml"
+            and "pickle/pkl". If not specified, it is inferred from the file
+            extension.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+        **kwargs: Passed on to the handler's ``load_from_fileobj``.
+
+    Examples:
+        >>> load('/path/of/your/file')  # file is stored on disk
+        >>> load('https://path/of/your/file')  # file is stored on Internet
+        >>> load('s3://path/of/your/file')  # file is stored in petrel
+
+    Raises:
+        TypeError: If the format is unsupported or ``file`` has a wrong type.
+
+    Returns:
+        The content from the file.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None and is_str(file):
+        file_format = file.split('.')[-1]
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+    handler = file_handlers[file_format]
+
+    if not is_str(file):
+        if not hasattr(file, 'read'):
+            raise TypeError('"file" must be a filepath str or a file-object')
+        return handler.load_from_fileobj(file, **kwargs)
+
+    file_client = FileClient.infer_client(file_client_args, file)
+    # The buffer type (text or bytes) must match what the handler parses.
+    if handler.str_like:
+        buffer = StringIO(file_client.get_text(file))
+    else:
+        buffer = BytesIO(file_client.get(file))
+    with buffer as f:
+        return handler.load_from_fileobj(f, **kwargs)
+
+
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+    """Dump data to json/yaml/pickle strings or files.
+
+    This method provides a unified api for dumping data as strings or to files,
+    and also supports custom arguments for each file format.
+
+    Note:
+        In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+        files which are saved to different backends.
+
+    Args:
+        obj (any): The python object to be dumped.
+        file (str or :obj:`Path` or file-like object, optional): If not
+            specified, then the object is dumped to a str, otherwise to a file
+            specified by the filename or file-like object.
+        file_format (str, optional): Same as :func:`load`.
+        file_client_args (dict, optional): Arguments to instantiate a
+            :class:`mmcv.fileio.FileClient`. Default: None.
+        **kwargs: Passed on to the handler's dump method.
+
+    Examples:
+        >>> dump('hello world', '/path/of/your/file')  # disk
+        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel
+
+    Returns:
+        The handler's ``dump_to_str`` result if ``file`` is None, else None.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None:
+        if is_str(file):
+            file_format = file.split('.')[-1]
+        elif file is None:
+            raise ValueError(
+                'file_format must be specified since file is None')
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+
+    handler = file_handlers[file_format]
+    if file is None:
+        return handler.dump_to_str(obj, **kwargs)
+    elif is_str(file):
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.str_like:
+            with StringIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put_text(f.getvalue(), file)
+        else:
+            with BytesIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put(f.getvalue(), file)
+    elif hasattr(file, 'write'):
+        handler.dump_to_fileobj(obj, file, **kwargs)
+    else:
+        raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler, file_formats):
+    """Map each given file extension to ``handler`` in ``file_handlers``.
+
+    Args:
+        handler (:obj:`BaseFileHandler`): Handler instance to register.
+        file_formats (str or list[str]): File format(s) served by this
+            handler.
+    """
+    if not isinstance(handler, BaseFileHandler):
+        raise TypeError(
+            f'handler must be a child of BaseFileHandler, not {type(handler)}')
+    # Accept a single extension as shorthand for a one-element list.
+    formats = [file_formats] if isinstance(file_formats, str) else file_formats
+    if not is_list_of(formats, str):
+        raise TypeError('file_formats must be a str or a list of str')
+    # An extension that is already registered is simply overwritten.
+    file_handlers.update({ext: handler for ext in formats})
+
+
+def register_handler(file_formats, **kwargs):
+    """Class decorator: register ``cls(**kwargs)`` for ``file_formats``."""
+    def decorator(cls):
+        _register_handler(cls(**kwargs), file_formats)
+        return cls
+
+    return decorator
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/parse.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/fileio/parse.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+
+from .file_client import FileClient
+
+
+def list_from_file(filename,
+                   prefix='',
+                   offset=0,
+                   max_num=0,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a list of strings.
+
+    Note:
+        In v1.3.16 and later, ``list_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+        a list of strings.
+
+    Args:
+        filename (str): Filename.
+        prefix (str): The prefix to be inserted to the beginning of each item.
+        offset (int): The number of leading lines to skip.
+        max_num (int): The maximum number of lines to be read,
+            zeros and negatives mean no limitation.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> list_from_file('/path/of/your/file')  # disk
+        ['hello', 'world']
+        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
+        ['hello', 'world']
+
+    Returns:
+        list[str]: A list of strings.
+    """
+    file_client = FileClient.infer_client(file_client_args, filename)
+    item_list = []
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        # Discard the first ``offset`` lines.
+        for _ in range(offset):
+            f.readline()
+        for index, line in enumerate(f):
+            # A non-positive ``max_num`` disables the limit.
+            if max_num > 0 and index >= max_num:
+                break
+            item_list.append(prefix + line.rstrip('\n\r'))
+    return item_list
+
+
+def dict_from_file(filename,
+                   key_type=str,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a dict.
+
+    Blank lines are skipped. Every other line must hold two or more columns
+    split by whitespace: the first column becomes the dict key and the rest
+    its value (a list when there are several value columns).
+
+    Note:
+        In v1.3.16 and later, the text file can be stored in different
+        backends.
+
+    Args:
+        filename(str): Filename.
+        key_type(type): Type of the dict keys. str is used by default and
+            type conversion will be performed if specified.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dict_from_file('/path/of/your/file')  # disk
+        {'key1': 'value1', 'key2': 'value2'}
+        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
+        {'key1': 'value1', 'key2': 'value2'}
+
+    Returns:
+        dict: The parsed contents.
+    """
+    mapping = {}
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        for line in f:
+            items = line.rstrip('\n').split()
+            if not items:  # blank line: nothing to parse
+                continue
+            assert len(items) >= 2
+            val = items[1:] if len(items) > 2 else items[1]
+            mapping[key_type(items[0])] = val
+    return mapping
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/image/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+ impad_to_multiple, imrescale, imresize, imresize_like,
+ imresize_to_multiple, imrotate, imshear, imtranslate,
+ rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+ adjust_lighting, adjust_sharpness, auto_contrast,
+ clahe, imdenormalize, imequalize, iminvert,
+ imnormalize, imnormalize_, lut_transform, posterize,
+ solarize)
+
+__all__ = [
+ 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+ 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+ 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+ 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+ 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+ 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+ 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+ 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+ 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+ 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/image/colorspace.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/colorspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/colorspace.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+
+def imconvert(img, src, dst):
+    """Convert an image between two colorspaces known to OpenCV.
+
+    Args:
+        img (ndarray): Image to convert.
+        src (str): Name of the source colorspace, e.g. 'rgb' or 'hsv'.
+        dst (str): Name of the target colorspace, e.g. 'rgb' or 'hsv'.
+
+    Returns:
+        ndarray: The image expressed in the ``dst`` colorspace.
+    """
+    # Resolve the cv2 constant COLOR_<SRC>2<DST> from the two names.
+    flag_name = f'COLOR_{src.upper()}2{dst.upper()}'
+    return cv2.cvtColor(img, getattr(cv2, flag_name))
+
+
+def bgr2gray(img, keepdim=False):
+    """Turn a BGR image into a grayscale image.
+
+    Args:
+        img (ndarray): The BGR input image.
+        keepdim (bool): When True, a trailing axis of length 1 is appended
+            to the converted image. Defaults to False.
+
+    Returns:
+        ndarray: The grayscale image.
+    """
+    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    if not keepdim:
+        return gray
+    return gray[..., None]
+
+
+def rgb2gray(img, keepdim=False):
+    """Turn an RGB image into a grayscale image.
+
+    Args:
+        img (ndarray): The RGB input image.
+        keepdim (bool): When True, a trailing axis of length 1 is appended
+            to the converted image. Defaults to False.
+
+    Returns:
+        ndarray: The grayscale image.
+    """
+    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    if not keepdim:
+        return gray
+    return gray[..., None]
+
+
+def gray2bgr(img):
+    """Expand a grayscale image into a BGR image.
+
+    Args:
+        img (ndarray): The grayscale input; 2-D input gets a channel axis.
+
+    Returns:
+        ndarray: The BGR image produced with ``cv2.COLOR_GRAY2BGR``.
+    """
+    if img.ndim == 2:
+        img = img[..., None]
+    return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+
+def gray2rgb(img):
+    """Expand a grayscale image into an RGB image.
+
+    Args:
+        img (ndarray): The grayscale input; 2-D input gets a channel axis.
+
+    Returns:
+        ndarray: The RGB image produced with ``cv2.COLOR_GRAY2RGB``.
+    """
+    if img.ndim == 2:
+        img = img[..., None]
+    return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+
+
+def _convert_input_type_range(img):
+    """Bring an input image to np.float32 type and the range [0, 1].
+
+    Pre-processing helper for the colorspace conversions in this module,
+    such as rgb2ycbcr and ycbcr2rgb. uint8 input is divided by 255, while
+    float32 input is only copied.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): A new array of type np.float32 and range [0, 1].
+
+    Raises:
+        TypeError: If the dtype of ``img`` is neither float32 nor uint8.
+    """
+    src_dtype = img.dtype
+    converted = img.astype(np.float32)
+    if src_dtype == np.uint8:
+        converted /= 255.
+    elif src_dtype != np.float32:
+        raise TypeError('The img type should be np.float32 or np.uint8, '
+                        f'but got {src_dtype}')
+    return converted
+
+
+def _convert_output_type_range(img, dst_type):
+    """Cast a float image in range [0, 255] to the requested output type.
+
+    Post-processing helper for the colorspace conversions in this module,
+    such as rgb2ycbcr and ycbcr2rgb:
+
+    - np.uint8: values are rounded and keep the range [0, 255];
+    - np.float32: values are divided by 255 (in place on ``img``), which
+      gives the range [0, 1].
+
+    Args:
+        img (ndarray): The image to be converted, expected as np.float32
+            type with range [0, 255].
+        dst_type (np.uint8 | np.float32): The desired output type, which
+            also selects the output range as described above.
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+
+    Raises:
+        TypeError: If ``dst_type`` is neither np.uint8 nor np.float32.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError('The dst_type should be np.float32 or np.uint8, '
+                        f'but got {dst_type}')
+    if dst_type != np.uint8:
+        img /= 255.
+        return img.astype(dst_type)
+    return img.round().astype(dst_type)
+
+
+def rgb2ycbcr(img, y_only=False):
+    """Convert an RGB image to a YCbCr image.
+
+    The output is the same as that of Matlab's `rgb2ycbcr` function.
+    The conversion implemented here is the ITU-R BT.601 one for
+    standard-definition television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    This is not the conversion behind cv2.cvtColor's `RGB <-> YCrCb`:
+    OpenCV implements the JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The YCbCr image (or only its Y channel), with the same
+            type and range as the input image.
+    """
+    src_dtype = img.dtype
+    unit_img = _convert_input_type_range(img)
+    if y_only:
+        ycbcr = np.dot(unit_img, [65.481, 128.553, 24.966]) + 16.0
+    else:
+        # Rows are ordered R, G, B; columns give Y, Cb and Cr.
+        weights = [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                   [24.966, 112.0, -18.214]]
+        ycbcr = np.matmul(unit_img, weights) + [16, 128, 128]
+    return _convert_output_type_range(ycbcr, src_dtype)
+
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to a YCbCr image.
+
+    This is the BGR version of rgb2ycbcr.
+    The conversion implemented here is the ITU-R BT.601 one for
+    standard-definition television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    This is not the conversion behind cv2.cvtColor's `BGR <-> YCrCb`:
+    OpenCV implements the JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The YCbCr image (or only its Y channel), with the same
+            type and range as the input image.
+    """
+    src_dtype = img.dtype
+    unit_img = _convert_input_type_range(img)
+    if y_only:
+        ycbcr = np.dot(unit_img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        # Rows are ordered B, G, R; columns give Y, Cb and Cr.
+        weights = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                   [65.481, -37.797, 112.0]]
+        ycbcr = np.matmul(unit_img, weights) + [16, 128, 128]
+    return _convert_output_type_range(ycbcr, src_dtype)
+
+
+def ycbcr2rgb(img):
+    """Convert a YCbCr image to an RGB image.
+
+    The output is the same as that of Matlab's `ycbcr2rgb` function.
+    The conversion implemented here is the ITU-R BT.601 one for
+    standard-definition television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    This is not the conversion behind cv2.cvtColor's `YCrCb <-> RGB`:
+    OpenCV implements the JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted RGB image. The output image has the same type
+            and range as input image.
+    """
+    src_dtype = img.dtype
+    # Work on float values scaled back to the range [0, 255].
+    ycbcr = _convert_input_type_range(img) * 255
+    weights = [[0.00456621, 0.00456621, 0.00456621],
+               [0, -0.00153632, 0.00791071],
+               [0.00625893, -0.00318811, 0]]
+    offsets = [-222.921, 135.576, -276.836]
+    rgb = np.matmul(ycbcr, weights) * 255.0 + offsets
+    return _convert_output_type_range(rgb, src_dtype)
+
+
+def ycbcr2bgr(img):
+    """Convert a YCbCr image to a BGR image.
+
+    This is the BGR version of ycbcr2rgb.
+    The conversion implemented here is the ITU-R BT.601 one for
+    standard-definition television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    This is not the conversion behind cv2.cvtColor's `YCrCb <-> BGR`:
+    OpenCV implements the JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted BGR image. The output image has the same type
+            and range as input image.
+    """
+    src_dtype = img.dtype
+    # Work on float values scaled back to the range [0, 255].
+    ycbcr = _convert_input_type_range(img) * 255
+    weights = [[0.00456621, 0.00456621, 0.00456621],
+               [0.00791071, -0.00153632, 0],
+               [0, -0.00318811, 0.00625893]]
+    offsets = [-276.836, 135.576, -222.921]
+    bgr = np.matmul(ycbcr, weights) * 255.0 + offsets
+    return _convert_output_type_range(bgr, src_dtype)
+
+
+def convert_color_factory(src, dst):
+ """Build a one-argument function that converts images from src to dst."""
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+ def convert_color(img):
+ out_img = cv2.cvtColor(img, code)
+ return out_img
+
+ convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+ image.
+
+ Args:
+ img (ndarray or str): The input image.
+
+ Returns:
+ ndarray: The converted {dst.upper()} image.
+ """
+
+ return convert_color
+
+
+# Ready-made colorspace converters built with convert_color_factory. Each
+# one takes a single image argument and passes it to cv2.cvtColor with the
+# conversion code looked up for its (src, dst) pair.
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+hls2bgr = convert_color_factory('hls', 'bgr')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/image/geometric.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/geometric.py
@@ -0,0 +1,728 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+ from PIL import Image
+except ImportError:
+ Image = None
+
+
+def _scale_size(size, scale):
+    """Rescale a size by a ratio.
+
+    Args:
+        size (tuple[int]): (w, h).
+        scale (float | tuple(float)): One factor for both edges, or (fw, fh).
+
+    Returns:
+        tuple[int]: The scaled (w, h); each edge is int(edge * factor + 0.5).
+    """
+    factors = (scale, scale) if isinstance(scale, (float, int)) else scale
+    w, h = size
+    return (int(w * float(factors[0]) + 0.5),
+            int(h * float(factors[1]) + 0.5))
+
+
+cv2_interp_codes = {  # interpolation name -> OpenCV interpolation flag
+ 'nearest': cv2.INTER_NEAREST,
+ 'bilinear': cv2.INTER_LINEAR,
+ 'bicubic': cv2.INTER_CUBIC,
+ 'area': cv2.INTER_AREA,
+ 'lanczos': cv2.INTER_LANCZOS4
+}
+# Pillow counterpart of the table above; only defined when PIL was imported.
+if Image is not None:
+ pillow_interp_codes = {
+ 'nearest': Image.NEAREST,
+ 'bilinear': Image.BILINEAR,
+ 'bicubic': Image.BICUBIC,
+ 'box': Image.BOX,
+ 'lanczos': Image.LANCZOS,
+ 'hamming': Image.HAMMING
+ }
+
+
+def imresize(img,
+             size,
+             return_scale=False,
+             interpolation='bilinear',
+             out=None,
+             backend=None):
+    """Resize image to a given size.
+
+    Args:
+        img (ndarray): The input image.
+        size (tuple[int]): Target size (w, h).
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): A key of `cv2_interp_codes` for the 'cv2'
+            backend, or of `pillow_interp_codes` for the 'pillow' backend.
+        out (ndarray): The output destination; only passed to 'cv2'.
+        backend (str | None): The image resize backend type, `cv2` or
+            `pillow`. None falls back to the `imread_backend` name that
+            this module imports from `.io`. Default: None.
+
+    Returns:
+        tuple | ndarray: (resized_img, w_scale, h_scale) or resized_img.
+
+    Raises:
+        ValueError: If the backend is neither 'cv2' nor 'pillow'.
+        ImportError: If 'pillow' is requested but PIL is not importable.
+    """
+    h, w = img.shape[:2]
+    if backend is None:
+        backend = imread_backend
+    if backend not in ['cv2', 'pillow']:
+        raise ValueError(f'backend: {backend} is not supported for resize. '
+                         f"Supported backends are 'cv2', 'pillow'")
+    if backend == 'pillow':
+        if Image is None:
+            raise ImportError('`Pillow` is not installed')
+        assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
+        pil_image = Image.fromarray(img)
+        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+        resized_img = np.array(pil_image)
+    else:
+        resized_img = cv2.resize(
+            img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+    if not return_scale:
+        return resized_img
+    return resized_img, size[0] / w, size[1] / h
+
+
+def imresize_to_multiple(img,
+                         divisor,
+                         size=None,
+                         scale_factor=None,
+                         keep_ratio=False,
+                         return_scale=False,
+                         interpolation='bilinear',
+                         out=None,
+                         backend=None):
+    """Resize an image to a size whose edges are multiples of a divisor.
+
+    The target is given by ``size`` or ``scale_factor``; each edge of it is
+    then rounded up to the nearest multiple of the matching divisor.
+
+    Args:
+        img (ndarray): The input image.
+        divisor (int | tuple): Resized image size will be a multiple of
+            divisor. If divisor is a tuple, divisor should be
+            (w_divisor, h_divisor).
+        size (None | int | tuple[int]): Target size (w, h). Default: None.
+        scale_factor (None | float | tuple[float]): Multiplier for spatial
+            size, either one number or (w_scale_factor, h_scale_factor).
+            Default: None.
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image. Only used together with ``size``. Default: False.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Same as :func:`imresize`.
+        out (ndarray): The output destination.
+        backend (str | None): Same as :func:`imresize`.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+
+    Raises:
+        ValueError: Unless exactly one of ``size``/``scale_factor`` is set.
+    """
+    src_h, src_w = img.shape[:2]
+    has_size = size is not None
+    has_factor = scale_factor is not None
+    if has_size and has_factor:
+        raise ValueError('only one of size or scale_factor should be defined')
+    if not has_size and not has_factor:
+        raise ValueError('one of size or scale_factor should be defined')
+    if has_size:
+        target = to_2tuple(size)
+        if keep_ratio:
+            target = rescale_size((src_w, src_h), target, return_scale=False)
+    else:
+        target = _scale_size((src_w, src_h), scale_factor)
+    # Round every edge up to the next multiple of its divisor.
+    target = tuple(
+        int(np.ceil(edge / div)) * div
+        for edge, div in zip(target, to_2tuple(divisor)))
+    result = imresize(
+        img,
+        target,
+        return_scale=True,
+        interpolation=interpolation,
+        out=out,
+        backend=backend)
+    return result if return_scale else result[0]
+
+
+def imresize_like(img,
+                  dst_img,
+                  return_scale=False,
+                  interpolation='bilinear',
+                  backend=None):
+    """Resize an image so that it matches the size of another image.
+
+    Args:
+        img (ndarray): The input image.
+        dst_img (ndarray): The image whose height and width are targeted.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Same as :func:`imresize`.
+        backend (str | None): Same as :func:`imresize`.
+
+    Returns:
+        tuple or ndarray: Same as :func:`imresize`.
+    """
+    dst_h, dst_w = dst_img.shape[:2]
+    return imresize(img, (dst_w, dst_h), return_scale=return_scale,
+                    interpolation=interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+    """Calculate the new size to be rescaled to.
+
+    Args:
+        old_size (tuple[int]): The old size (w, h) of image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image size.
+
+    Returns:
+        tuple[int]: The new size, or (new size, factor) if ``return_scale``.
+
+    Raises:
+        ValueError: If ``scale`` is a number that is not positive.
+        TypeError: If ``scale`` is neither a number nor a tuple.
+    """
+    w, h = old_size
+    if isinstance(scale, tuple):
+        # Largest factor for which the long edge stays within max(scale)
+        # and the short edge stays within min(scale).
+        long_limit, short_limit = max(scale), min(scale)
+        factor = min(long_limit / max(h, w), short_limit / min(h, w))
+    elif not isinstance(scale, (float, int)):
+        raise TypeError(
+            f'Scale must be a number or tuple of int, but got {type(scale)}')
+    elif scale <= 0:
+        raise ValueError(f'Invalid scale {scale}, must be positive.')
+    else:
+        factor = scale
+
+    scaled = _scale_size((w, h), factor)
+    return (scaled, factor) if return_scale else scaled
+
+
+def imrescale(img,
+              scale,
+              return_scale=False,
+              interpolation='bilinear',
+              backend=None):
+    """Resize image while keeping the aspect ratio.
+
+    Args:
+        img (ndarray): The input image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image.
+        interpolation (str): Same as :func:`imresize`.
+        backend (str | None): Same as :func:`imresize`.
+
+    Returns:
+        ndarray | tuple: The rescaled image, or (rescaled image, scale
+            factor) when ``return_scale`` is set.
+    """
+    src_h, src_w = img.shape[:2]
+    target, factor = rescale_size((src_w, src_h), scale, return_scale=True)
+    resized = imresize(
+        img, target, interpolation=interpolation, backend=backend)
+    if not return_scale:
+        return resized
+    return resized, factor
+
+
+def imflip(img, direction='horizontal'):
+    """Flip an image horizontally, vertically or diagonally.
+
+    Unlike :func:`imflip_`, the input array is not written to.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal". "horizontal" reverses axis 1,
+            "vertical" reverses axis 0 and "diagonal" reverses both.
+
+    Returns:
+        ndarray: The flipped image, as produced by ``np.flip``.
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    # 'diagonal' flips both axes; it is also the fallback of the lookup.
+    flip_axis = {'horizontal': 1, 'vertical': 0}.get(direction, (0, 1))
+    return np.flip(img, axis=flip_axis)
+
+
+def imflip_(img, direction='horizontal'):
+    """Flip an image in place, horizontally, vertically or diagonally.
+
+    The result is written back into ``img`` through ``cv2.flip``.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image (inplace).
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    # cv2 flip codes: 1 flips around the y-axis, 0 around the x-axis and
+    # a negative value around both axes.
+    flip_code = {'horizontal': 1, 'vertical': 0}.get(direction, -1)
+    return cv2.flip(img, flip_code, img)
+
+
+def imrotate(img,
+             angle,
+             center=None,
+             scale=1.0,
+             border_value=0,
+             interpolation='bilinear',
+             auto_bound=False):
+    """Rotate an image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees, positive values mean
+            clockwise rotation.
+        center (tuple[float], optional): Center point (w, h) of the rotation
+            in the source image. Defaults to the center of the image.
+        scale (float): Isotropic scale factor.
+        border_value (int): Border value.
+        interpolation (str): A key of `cv2_interp_codes`.
+        auto_bound (bool): Whether to adjust the image size to cover the
+            whole rotated image. Passing it together with ``center`` raises
+            a ValueError.
+
+    Returns:
+        ndarray: The rotated image.
+    """
+    if auto_bound and center is not None:
+        raise ValueError('`auto_bound` conflicts with `center`')
+    src_h, src_w = img.shape[:2]
+    if center is None:
+        center = ((src_w - 1) * 0.5, (src_h - 1) * 0.5)
+    assert isinstance(center, tuple)
+
+    # cv2 treats positive angles as counter-clockwise, hence the sign flip.
+    affine = cv2.getRotationMatrix2D(center, -angle, scale)
+    out_w, out_h = src_w, src_h
+    if auto_bound:
+        abs_cos, abs_sin = np.abs(affine[0, 0]), np.abs(affine[0, 1])
+        bound_w = src_h * abs_sin + src_w * abs_cos
+        bound_h = src_h * abs_cos + src_w * abs_sin
+        # Shift so the rotated content is centred in the enlarged canvas.
+        affine[0, 2] += (bound_w - src_w) * 0.5
+        affine[1, 2] += (bound_h - src_h) * 0.5
+        out_w, out_h = int(np.round(bound_w)), int(np.round(bound_h))
+    return cv2.warpAffine(
+        img,
+        affine, (out_w, out_h),
+        flags=cv2_interp_codes[interpolation],
+        borderValue=border_value)
+
+
+def bbox_clip(bboxes, img_shape):
+    """Clip bboxes so that they lie inside an image of the given shape.
+
+    Args:
+        bboxes (ndarray): Shape (..., 4*k)
+        img_shape (tuple[int]): (height, width) of the image.
+
+    Returns:
+        ndarray: Clipped bboxes.
+    """
+    assert bboxes.shape[-1] % 4 == 0
+    # Even slots hold x (capped at w - 1), odd slots hold y (capped at h - 1).
+    upper = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+    upper[0::2] = img_shape[1] - 1
+    upper[1::2] = img_shape[0] - 1
+    return np.maximum(np.minimum(bboxes, upper), 0)
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+    """Scale bboxes around their own centers.
+
+    Args:
+        bboxes (ndarray): Shape(..., 4).
+        scale (float): Scaling factor.
+        clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+            boundary will be clipped according to the given shape (h, w).
+
+    Returns:
+        ndarray: Scaled bboxes.
+    """
+    if float(scale) != 1.0:
+        box_w = bboxes[..., 2] - bboxes[..., 0] + 1
+        box_h = bboxes[..., 3] - bboxes[..., 1] + 1
+        grow_w = (box_w * (scale - 1)) * 0.5
+        grow_h = (box_h * (scale - 1)) * 0.5
+        deltas = np.stack((-grow_w, -grow_h, grow_w, grow_h), axis=-1)
+        result = bboxes + deltas
+    else:
+        result = bboxes.copy()
+    if clip_shape is None:
+        return result
+    return bbox_clip(result, clip_shape)
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+    """Crop image patches.
+
+    3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+    Args:
+        img (ndarray): Image to be cropped.
+        bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+        scale (float, optional): Scale ratio of bboxes, the default value
+            1.0 means no padding.
+        pad_fill (Number | list[Number]): Value to be filled for padding.
+            Default: None, which means no padding.
+
+    Returns:
+        list[ndarray] | ndarray: The cropped image patches.
+    """
+    chn = 1 if img.ndim == 2 else img.shape[2]
+    if pad_fill is not None:
+        if isinstance(pad_fill, (int, float)):
+            pad_fill = [pad_fill for _ in range(chn)]
+        assert len(pad_fill) == chn
+
+    _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+    scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+    clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+    patches = []
+    for i in range(clipped_bbox.shape[0]):
+        x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+        if pad_fill is None:
+            patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+        else:
+            _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+            # Follow the rank of ``img`` (not the channel count) so that a
+            # (h, w, 1) image gets a 3-D canvas its pixels can be copied to.
+            patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+            if img.ndim != 2:
+                patch_shape += (chn, )
+            patch = np.array(
+                pad_fill, dtype=img.dtype) * np.ones(
+                    patch_shape, dtype=img.dtype)
+            x_start = 0 if _x1 >= 0 else -_x1
+            y_start = 0 if _y1 >= 0 else -_y1
+            w = x2 - x1 + 1
+            h = y2 - y1 + 1
+            patch[y_start:y_start + h, x_start:x_start + w,
+                  ...] = img[y1:y1 + h, x1:x1 + w, ...]
+        patches.append(patch)
+
+    if bboxes.ndim == 1:
+        return patches[0]
+    return patches
+
+
+def impad(img,
+          *,
+          shape=None,
+          padding=None,
+          pad_val=0,
+          padding_mode='constant'):
+    """Pad the given image to a certain shape or pad on all sides with
+    specified padding mode and padding value.
+
+    Args:
+        img (ndarray): Image to be padded.
+        shape (tuple[int]): Expected padding shape (h, w). An edge that is
+            not larger than the image edge adds no padding. Default: None.
+        padding (int or tuple[int]): Padding on each border. If a single int is
+            provided this is used to pad all borders. If tuple of length 2 is
+            provided this is the padding on left/right and top/bottom
+            respectively. If a tuple of length 4 is provided this is the
+            padding for the left, top, right and bottom borders respectively.
+            Default: None. Exactly one of `shape`/`padding` must be given.
+        pad_val (Number | Sequence[Number]): Values to be filled in padding
+            areas when padding_mode is 'constant'. Default: 0.
+        padding_mode (str): Type of padding. Should be: constant, edge,
+            reflect or symmetric. Default: constant.
+
+            - constant: pads with a constant value, this value is specified
+                with pad_val.
+            - edge: pads with the last value at the edge of the image.
+            - reflect: pads with reflection of image without repeating the
+                last value on the edge. For example, padding [1, 2, 3, 4]
+                with 2 elements on both sides in reflect mode will result
+                in [3, 2, 1, 2, 3, 4, 3, 2].
+            - symmetric: pads with reflection of image repeating the last
+                value on the edge. For example, padding [1, 2, 3, 4] with
+                2 elements on both sides in symmetric mode will result in
+                [2, 1, 1, 2, 3, 4, 4, 3]
+
+    Returns:
+        ndarray: The padded image.
+    """
+    assert (shape is not None) ^ (padding is not None)
+    if shape is not None:
+        # Never pass a negative border on; cv2.copyMakeBorder rejects it.
+        pad_w = max(shape[1] - img.shape[1], 0)
+        pad_h = max(shape[0] - img.shape[0], 0)
+        padding = (0, 0, pad_w, pad_h)
+
+    # check pad_val
+    if isinstance(pad_val, tuple):
+        assert len(pad_val) == img.shape[-1]
+    elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be a int or a tuple. '
+                        f'But received {type(pad_val)}')
+
+    # check padding
+    if isinstance(padding, tuple) and len(padding) in [2, 4]:
+        if len(padding) == 2:
+            padding = (padding[0], padding[1], padding[0], padding[1])
+    elif isinstance(padding, numbers.Number):
+        padding = (padding, padding, padding, padding)
+    else:
+        raise ValueError('Padding must be a int or a 2, or 4 element tuple. '
+                         f'But received {padding}')
+
+    # check padding mode
+    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+    border_type = {
+        'constant': cv2.BORDER_CONSTANT,
+        'edge': cv2.BORDER_REPLICATE,
+        'reflect': cv2.BORDER_REFLECT_101,
+        'symmetric': cv2.BORDER_REFLECT
+    }
+    return cv2.copyMakeBorder(
+        img,
+        padding[1],
+        padding[3],
+        padding[0],
+        padding[2],
+        border_type[padding_mode],
+        value=pad_val)
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+    """Pad an image so that both of its edges are multiples of a number.
+
+    Args:
+        img (ndarray): Image to be padded.
+        divisor (int): Both padded edges will be multiples of this value.
+        pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+    Returns:
+        ndarray: The padded image.
+    """
+    target_shape = tuple(
+        int(np.ceil(img.shape[axis] / divisor)) * divisor for axis in (0, 1))
+    return impad(img, shape=target_shape, pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+    """Randomly cut out a rectangle from the original img.
+
+    Args:
+        img (ndarray): Image to be cutout.
+        shape (int | tuple[int]): Expected cutout shape (h, w). If given as a
+            int, the value will be used for both h and w.
+        pad_val (int | float | tuple[int | float]): Values to be filled in the
+            cut area. Defaults to 0.
+
+    Returns:
+        ndarray: The cutout image; ``img`` itself is left unchanged.
+    """
+    channels = 1 if img.ndim == 2 else img.shape[2]
+    if isinstance(shape, int):
+        cut_h, cut_w = shape, shape
+    else:
+        assert isinstance(shape, tuple) and len(shape) == 2, \
+            f'shape must be a int or a tuple with length 2, but got type ' \
+            f'{type(shape)} instead.'
+        cut_h, cut_w = shape
+    if isinstance(pad_val, (int, float)):
+        pad_val = tuple([pad_val] * channels)
+    elif isinstance(pad_val, tuple):
+        assert len(pad_val) == channels, \
+            'Expected the num of elements in tuple equals the channels ' \
+            'of input image. Found {} vs {}'.format(
+                len(pad_val), channels)
+    else:
+        raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+    img_h, img_w = img.shape[:2]
+    # NOTE(review): a single positional argument is `low`, so these draw
+    # between `high`'s default of 1.0 and the image edge - confirm intent.
+    y0 = np.random.uniform(img_h)
+    x0 = np.random.uniform(img_w)
+    # (y0, x0) is the cut centre; clamp the rectangle to the image.
+    y1 = int(max(0, y0 - cut_h / 2.))
+    x1 = int(max(0, x0 - cut_w / 2.))
+    y2 = min(img_h, y1 + cut_h)
+    x2 = min(img_w, x1 + cut_w)
+
+    if img.ndim == 2:
+        patch_shape = (y2 - y1, x2 - x1)
+    else:
+        patch_shape = (y2 - y1, x2 - x1, channels)
+
+    img_cutout = img.copy()
+    patch = np.array(
+        pad_val, dtype=img.dtype) * np.ones(
+            patch_shape, dtype=img.dtype)
+    img_cutout[y1:y2, x1:x2, ...] = patch
+    return img_cutout
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+    """Generate the shear matrix for transformation.
+
+    Args:
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The shear direction, "horizontal" or "vertical".
+            Any other value raises ValueError.
+
+    Returns:
+        ndarray: The shear matrix with dtype float32.
+    """
+    if direction == 'horizontal':
+        return np.float32([[1, magnitude, 0], [0, 1, 0]])
+    if direction == 'vertical':
+        return np.float32([[1, 0, 0], [magnitude, 1, 0]])
+    raise ValueError(f'Invalid direction: {direction}')
+
+
+def imshear(img,
+            magnitude,
+            direction='horizontal',
+            border_value=0,
+            interpolation='bilinear'):
+    """Shear an image.
+
+    Args:
+        img (ndarray): Image to be sheared with format (h, w)
+            or (h, w, c).
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The shear direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): A key of `cv2_interp_codes`.
+
+    Returns:
+        ndarray: The sheared image.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+    if isinstance(border_value, int):
+        border_value = tuple([border_value] * channels)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels ' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`')
+    shear_matrix = _get_shear_matrix(magnitude, direction)
+    sheared = cv2.warpAffine(
+        img,
+        shear_matrix,
+        (width, height),
+        # Note: when `border_value` has more than 3 elements (e.g. when
+        # shearing masks with more than 3 channels), `cv2.warpAffine`
+        # raises a TypeError, so only the first 3 values of
+        # `border_value` are passed on.
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+    return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+    """Generate the translate matrix.
+
+    Args:
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, "horizontal" or
+            "vertical". Any other value raises ValueError.
+
+    Returns:
+        ndarray: The translate matrix with dtype float32.
+    """
+    if direction == 'horizontal':
+        return np.float32([[1, 0, offset], [0, 1, 0]])
+    if direction == 'vertical':
+        return np.float32([[1, 0, 0], [0, 1, offset]])
+    raise ValueError(f'Invalid direction: {direction}')
+
+
+def imtranslate(img,
+                offset,
+                direction='horizontal',
+                border_value=0,
+                interpolation='bilinear'):
+    """Translate an image.
+
+    Args:
+        img (ndarray): Image to be translated with format
+            (h, w) or (h, w, c).
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): A key of `cv2_interp_codes`.
+
+    Returns:
+        ndarray: The translated image.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+    if isinstance(border_value, int):
+        border_value = tuple([border_value] * channels)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels ' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`.')
+    translate_matrix = _get_translate_matrix(offset, direction)
+    translated = cv2.warpAffine(
+        img,
+        translate_matrix,
+        (width, height),
+        # Note: when `border_value` has more than 3 elements (e.g. when
+        # translating masks with more than 3 channels), `cv2.warpAffine`
+        # raises a TypeError, so only the first 3 values of
+        # `border_value` are passed on.
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+    return translated
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/image/io.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e8f1877978840aede93774d86643b129751db13
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/io.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+from pathlib import Path
+
+import cv2
+import numpy as np
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+ IMREAD_UNCHANGED)
+
+from annotator.mmpkg.mmcv.utils import check_file_exist, is_str, mkdir_or_exist
+
+try:
+ from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+ TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+try:
+ from PIL import Image, ImageOps
+except ImportError:
+ Image = None
+
+try:
+ import tifffile
+except ImportError:
+ tifffile = None
+
+# Shared TurboJPEG decoder instance. It stays None until `use_backend`
+# is called with 'turbojpeg', which creates it once and reuses it.
+jpeg = None
+# Backend names accepted by `use_backend`, `imread` and `imfrombytes`.
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+# Maps the string `flag` values accepted by the readers onto the OpenCV
+# imread flag constants imported above.
+imread_flags = {
+ 'color': IMREAD_COLOR,
+ 'grayscale': IMREAD_GRAYSCALE,
+ 'unchanged': IMREAD_UNCHANGED,
+ 'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+ 'grayscale_ignore_orientation':
+ IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+}
+
+# Backend used when a reader is called with `backend=None`; reassigned by
+# `use_backend`.
+imread_backend = 'cv2'
+
+
+def use_backend(backend):
+ """Select a backend for image decoding.
+
+ Args:
+ backend (str): The image decoding backend type. Options are `cv2`,
+ `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+ and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+ file format.
+ """
+ assert backend in supported_backends
+ global imread_backend
+ imread_backend = backend
+ if imread_backend == 'turbojpeg':
+ if TurboJPEG is None:
+ raise ImportError('`PyTurboJPEG` is not installed')
+ global jpeg
+ if jpeg is None:
+ jpeg = TurboJPEG()
+ elif imread_backend == 'pillow':
+ if Image is None:
+ raise ImportError('`Pillow` is not installed')
+ elif imread_backend == 'tifffile':
+ if tifffile is None:
+ raise ImportError('`tifffile` is not installed')
+
+
+def _jpegflag(flag='color', channel_order='bgr'):
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'color':
+ if channel_order == 'bgr':
+ return TJPF_BGR
+ elif channel_order == 'rgb':
+ return TJCS_RGB
+ elif flag == 'grayscale':
+ return TJPF_GRAY
+ else:
+ raise ValueError('flag must be "color" or "grayscale"')
+
+
+def _pillow2array(img, flag='color', channel_order='bgr'):
+ """Convert a pillow image to numpy array.
+
+ Args:
+ img (:obj:`PIL.Image.Image`): The image loaded using PIL
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are 'color', 'grayscale', 'unchanged',
+ 'color_ignore_orientation' and 'grayscale_ignore_orientation'.
+ Default to 'color'.
+ channel_order (str): The channel order of the output image array,
+ candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+
+ Returns:
+ np.ndarray: The converted numpy array
+
+ Raises:
+ ValueError: If ``channel_order`` or ``flag`` is not one of the
+ supported values.
+ """
+ channel_order = channel_order.lower()
+ if channel_order not in ['rgb', 'bgr']:
+ raise ValueError('channel order must be either "rgb" or "bgr"')
+
+ if flag == 'unchanged':
+ array = np.array(img)
+ # NOTE(review): this swap does not consult `channel_order`, so
+ # 'unchanged' color images appear to always come back as BGR --
+ # confirm whether that is intended.
+ if array.ndim >= 3 and array.shape[2] >= 3: # color image
+ array[:, :, :3] = array[:, :, (2, 1, 0)] # RGB to BGR
+ else:
+ # Handle exif orientation tag
+ # (only the flags without the `_ignore_orientation` suffix do this)
+ if flag in ['color', 'grayscale']:
+ img = ImageOps.exif_transpose(img)
+ # If the image mode is not 'RGB', convert it to 'RGB' first.
+ if img.mode != 'RGB':
+ if img.mode != 'LA':
+ # Most formats except 'LA' can be directly converted to RGB
+ img = img.convert('RGB')
+ else:
+ # When the mode is 'LA', the default conversion will fill in
+ # the canvas with black, which sometimes shadows black objects
+ # in the foreground.
+ #
+ # Therefore, a random color (124, 117, 104) is used for canvas
+ img_rgba = img.convert('RGBA')
+ img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+ img.paste(img_rgba, mask=img_rgba.split()[3]) # 3 is alpha
+ if flag in ['color', 'color_ignore_orientation']:
+ array = np.array(img)
+ if channel_order != 'rgb':
+ array = array[:, :, ::-1] # RGB to BGR
+ elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+ # Collapse the RGB image produced above to a single 'L' channel.
+ img = img.convert('L')
+ array = np.array(img)
+ else:
+ raise ValueError(
+ 'flag must be "color", "grayscale", "unchanged", '
+ f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+ f' but got {flag}')
+ return array
+
+
+def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
+ """Read an image.
+
+ Args:
+ img_or_path (ndarray or str or Path): Either a numpy array or str or
+ pathlib.Path. If it is a numpy array (loaded image), then
+ it will be returned as is.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale`, `unchanged`,
+ `color_ignore_orientation` and `grayscale_ignore_orientation`.
+ By default, `cv2` and `pillow` backend would rotate the image
+ according to its EXIF info unless called with `unchanged` or
+ `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+ always ignore image's EXIF info regardless of the flag.
+ The `turbojpeg` backend only supports `color` and `grayscale`.
+ channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+ If backend is None, the global imread_backend specified by
+ ``mmcv.use_backend()`` will be used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow'")
+ if isinstance(img_or_path, Path):
+ img_or_path = str(img_or_path)
+
+ if isinstance(img_or_path, np.ndarray):
+ return img_or_path
+ elif is_str(img_or_path):
+ check_file_exist(img_or_path,
+ f'img file does not exist: {img_or_path}')
+ if backend == 'turbojpeg':
+ with open(img_or_path, 'rb') as in_file:
+ img = jpeg.decode(in_file.read(),
+ _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ img = Image.open(img_or_path)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ elif backend == 'tifffile':
+ img = tifffile.imread(img_or_path)
+ return img
+ else:
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imread(img_or_path, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+ else:
+ raise TypeError('"img" must be a numpy array or a str or '
+ 'a pathlib.Path object')
+
+
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Same as :func:`imread`.
+ backend (str | None): The image decoding backend type. Options are
+ `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
+ global imread_backend specified by ``mmcv.use_backend()`` will be
+ used. Default: None.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+
+ if backend is None:
+ backend = imread_backend
+ if backend not in supported_backends:
+ raise ValueError(f'backend: {backend} is not supported. Supported '
+ "backends are 'cv2', 'turbojpeg', 'pillow'")
+ if backend == 'turbojpeg':
+ img = jpeg.decode(content, _jpegflag(flag, channel_order))
+ if img.shape[-1] == 1:
+ img = img[:, :, 0]
+ return img
+ elif backend == 'pillow':
+ buff = io.BytesIO(content)
+ img = Image.open(buff)
+ img = _pillow2array(img, flag, channel_order)
+ return img
+ else:
+ img_np = np.frombuffer(content, np.uint8)
+ flag = imread_flags[flag] if is_str(flag) else flag
+ img = cv2.imdecode(img_np, flag)
+ if flag == IMREAD_COLOR and channel_order == 'rgb':
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = osp.abspath(osp.dirname(file_path))
+ mkdir_or_exist(dir_name)
+ return cv2.imwrite(file_path, img, params)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/image/misc.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd60e66131719ca0627569598809366b9c1ac64d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/misc.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import annotator.mmpkg.mmcv as mmcv
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+
+def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+ """Convert tensor to 3-channel images.
+
+ Args:
+ tensor (torch.Tensor): Tensor that contains multiple images, shape (
+ N, C, H, W).
+ mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
+ std (tuple[float], optional): Standard deviation of images.
+ Defaults to (1, 1, 1).
+ to_rgb (bool, optional): Whether the tensor was converted to RGB
+ format in the first place. If so, convert it back to BGR.
+ Defaults to True.
+
+ Returns:
+ list[np.ndarray]: A list that contains multiple images.
+ """
+
+ if torch is None:
+ raise RuntimeError('pytorch is not installed')
+ assert torch.is_tensor(tensor) and tensor.ndim == 4
+ assert len(mean) == 3
+ assert len(std) == 3
+
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = mmcv.imdenormalize(
+ img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/image/photometric.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/image/photometric.py
@@ -0,0 +1,428 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+
+
+def imnormalize(img, mean, std, to_rgb=True):
+ """Normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+ mean (ndarray): The mean to be used for normalize.
+ std (ndarray): The std to be used for normalize.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ img = img.copy().astype(np.float32)
+ return imnormalize_(img, mean, std, to_rgb)
+
+
+def imnormalize_(img, mean, std, to_rgb=True):
+ """Inplace normalize an image with mean and std.
+
+ Args:
+ img (ndarray): Image to be normalized.
+ mean (ndarray): The mean to be used for normalize.
+ std (ndarray): The std to be used for normalize.
+ to_rgb (bool): Whether to convert to rgb.
+
+ Returns:
+ ndarray: The normalized image.
+ """
+ # cv2 inplace normalization does not accept uint8
+ assert img.dtype != np.uint8
+ mean = np.float64(mean.reshape(1, -1))
+ stdinv = 1 / np.float64(std.reshape(1, -1))
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+ cv2.subtract(img, mean, img) # inplace
+ cv2.multiply(img, stdinv, img) # inplace
+ return img
+
+
+def imdenormalize(img, mean, std, to_bgr=True):
+ assert img.dtype != np.uint8
+ mean = mean.reshape(1, -1).astype(np.float64)
+ std = std.reshape(1, -1).astype(np.float64)
+ img = cv2.multiply(img, std) # make a copy
+ cv2.add(img, mean, img) # inplace
+ if to_bgr:
+ cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img) # inplace
+ return img
+
+
+def iminvert(img):
+ """Invert (negate) an image.
+
+ Args:
+ img (ndarray): Image to be inverted.
+
+ Returns:
+ ndarray: The inverted image.
+ """
+ return np.full_like(img, 255) - img
+
+
+def solarize(img, thr=128):
+ """Solarize an image (invert all pixel values above a threshold)
+
+ Args:
+ img (ndarray): Image to be solarized.
+ thr (int): Threshold for solarizing (0 - 255).
+
+ Returns:
+ ndarray: The solarized image.
+ """
+ img = np.where(img < thr, img, 255 - img)
+ return img
+
+
+def posterize(img, bits):
+ """Posterize an image (reduce the number of bits for each color channel)
+
+ Args:
+ img (ndarray): Image to be posterized.
+ bits (int): Number of bits (1 to 8) to use for posterizing.
+
+ Returns:
+ ndarray: The posterized image.
+ """
+ shift = 8 - bits
+ img = np.left_shift(np.right_shift(img, shift), shift)
+ return img
+
+
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+ r"""It blends the source image and its gray image:
+
+ .. math::
+ output = img * alpha + gray\_img * beta + gamma
+
+ Args:
+ img (ndarray): The input source image.
+ alpha (int | float): Weight for the source image. Default 1.
+ beta (int | float): Weight for the converted gray image.
+ If None, it's assigned the value (1 - `alpha`).
+ gamma (int | float): Scalar added to each sum.
+ Same as :func:`cv2.addWeighted`. Default 0.
+
+ Returns:
+ ndarray: Colored image which has the same size and dtype as input.
+ """
+ gray_img = bgr2gray(img)
+ gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+ if beta is None:
+ beta = 1 - alpha
+ colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+ if not colored_img.dtype == np.uint8:
+ # Note when the dtype of `img` is not the default `np.uint8`
+ # (e.g. np.float32), the value in `colored_img` got from cv2
+ # is not guaranteed to be in range [0, 255], so here clip
+ # is needed.
+ colored_img = np.clip(colored_img, 0, 255)
+ return colored_img
+
+
+def imequalize(img):
+ """Equalize the image histogram.
+
+ This function applies a non-linear mapping to the input image,
+ in order to create a uniform distribution of grayscale values
+ in the output image.
+
+ Args:
+ img (ndarray): Image to be equalized. Channels 0, 1 and 2 of the last
+ axis are processed, so a 3-channel image is expected.
+
+ Returns:
+ ndarray: The equalized image, cast back to the dtype of ``img``.
+ """
+
+ def _scale_channel(im, c):
+ """Scale the data in the corresponding channel."""
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # For computing the step, filter out the nonzeros.
+ nonzero_histo = histo[histo > 0]
+ # Pixel count excluding the last occupied bin, spread over 255 levels.
+ step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+ if not step:
+ # Identity table; the image is returned unchanged below.
+ lut = np.array(range(256))
+ else:
+ # Compute the cumulative sum, shifted by step // 2
+ # and then normalized by step.
+ lut = (np.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = np.concatenate([[0], lut[:-1]], 0)
+ # handle potential integer overflow
+ lut[lut > 255] = 255
+ # If step is zero, return the original image.
+ # Otherwise, index from lut.
+ return np.where(np.equal(step, 0), im, lut[im])
+
+ # Scales each channel independently and then stacks
+ # the result.
+ s1 = _scale_channel(img, 0)
+ s2 = _scale_channel(img, 1)
+ s3 = _scale_channel(img, 2)
+ equalized_img = np.stack([s1, s2, s3], axis=-1)
+ return equalized_img.astype(img.dtype)
+
+
+def adjust_brightness(img, factor=1.):
+ """Adjust image brightness.
+
+ This function controls the brightness of an image. An
+ enhancement factor of 0.0 gives a black image.
+ A factor of 1.0 gives the original image. This function
+ blends the source image and the degenerated black image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be brightened.
+ factor (float): A value controls the enhancement.
+ Factor 1.0 returns the original image, lower
+ factors mean less color (brightness, contrast,
+ etc), and higher values more. Default 1.
+
+ Returns:
+ ndarray: The brightened image.
+ """
+ degenerated = np.zeros_like(img)
+ # Note manually convert the dtype to np.float32, to
+ # achieve as close results as PIL.ImageEnhance.Brightness.
+ # Set beta=1-factor, and gamma=0
+ brightened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ brightened_img = np.clip(brightened_img, 0, 255)
+ return brightened_img.astype(img.dtype)
+
+
+def adjust_contrast(img, factor=1.):
+ """Adjust image contrast.
+
+ This function controls the contrast of an image. An
+ enhancement factor of 0.0 gives a solid grey
+ image. A factor of 1.0 gives the original image. It
+ blends the source image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+ gray_img = bgr2gray(img)
+ hist = np.histogram(gray_img, 256, (0, 255))[0]
+ mean = round(np.sum(gray_img) / np.sum(hist))
+ degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+ degenerated = gray2bgr(degenerated)
+ contrasted_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ contrasted_img = np.clip(contrasted_img, 0, 255)
+ return contrasted_img.astype(img.dtype)
+
+
+def auto_contrast(img, cutoff=0):
+ """Auto adjust image contrast.
+
+ This function maximize (normalize) image contrast by first removing cutoff
+ percent of the lightest and darkest pixels from the histogram and remapping
+ the image so that the darkest pixel becomes black (0), and the lightest
+ becomes white (255).
+
+ Args:
+ img (ndarray): Image to be contrasted. BGR order.
+ cutoff (int | float | tuple): The cutoff percent of the lightest and
+ darkest pixels to be removed. If given as tuple, it shall be
+ (low, high). Otherwise, the single value will be used for both.
+ Defaults to 0.
+
+ Returns:
+ ndarray: The contrasted image.
+ """
+
+ def _auto_contrast_channel(im, c, cutoff):
+ im = im[:, :, c]
+ # Compute the histogram of the image channel.
+ histo = np.histogram(im, 256, (0, 255))[0]
+ # Remove cut-off percent pixels from histo
+ histo_sum = np.cumsum(histo)
+ cut_low = histo_sum[-1] * cutoff[0] // 100
+ cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+ histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+ histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+ # Compute mapping
+ low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+ # If all the values have been cut off, return the origin img
+ if low >= high:
+ return im
+ scale = 255.0 / (high - low)
+ offset = -low * scale
+ lut = np.array(range(256))
+ lut = lut * scale + offset
+ lut = np.clip(lut, 0, 255)
+ return lut[im]
+
+ if isinstance(cutoff, (int, float)):
+ cutoff = (cutoff, cutoff)
+ else:
+ assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+ f'float or tuple, but got {type(cutoff)} instead.'
+ # Auto adjusts contrast for each channel independently and then stacks
+ # the result.
+ s1 = _auto_contrast_channel(img, 0, cutoff)
+ s2 = _auto_contrast_channel(img, 1, cutoff)
+ s3 = _auto_contrast_channel(img, 2, cutoff)
+ contrasted_img = np.stack([s1, s2, s3], axis=-1)
+ return contrasted_img.astype(img.dtype)
+
+
+def adjust_sharpness(img, factor=1., kernel=None):
+ """Adjust image sharpness.
+
+ This function controls the sharpness of an image. An
+ enhancement factor of 0.0 gives a blurred image. A
+ factor of 1.0 gives the original image. And a factor
+ of 2.0 gives a sharpened image. It blends the source
+ image and the degenerated mean image:
+
+ .. math::
+ output = img * factor + degenerated * (1 - factor)
+
+ Args:
+ img (ndarray): Image to be sharpened. BGR order.
+ factor (float): Same as :func:`mmcv.adjust_brightness`.
+ kernel (np.ndarray, optional): Filter kernel to be applied on the img
+ to obtain the degenerated img. Defaults to None.
+
+ Note:
+ No value sanity check is enforced on the kernel set by users. So with
+ an inappropriate kernel, the ``adjust_sharpness`` may fail to perform
+ the function its name indicates but end up performing whatever
+ transform determined by the kernel.
+
+ Returns:
+ ndarray: The sharpened image.
+ """
+
+ if kernel is None:
+ # adopted from PIL.ImageFilter.SMOOTH
+ kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+ assert isinstance(kernel, np.ndarray), \
+ f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+ assert kernel.ndim == 2, \
+ f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+ degenerated = cv2.filter2D(img, -1, kernel)
+ sharpened_img = cv2.addWeighted(
+ img.astype(np.float32), factor, degenerated.astype(np.float32),
+ 1 - factor, 0)
+ sharpened_img = np.clip(sharpened_img, 0, 255)
+ return sharpened_img.astype(img.dtype)
+
+
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+ """AlexNet-style PCA jitter.
+
+ This data augmentation is proposed in `ImageNet Classification with Deep
+ Convolutional Neural Networks
+ `_.
+
+ Args:
+ img (ndarray): Image to be adjusted lighting. BGR order.
+ eigval (ndarray): the eigenvalue of the convariance matrix of pixel
+ values, respectively.
+ eigvec (ndarray): the eigenvector of the convariance matrix of pixel
+ values, respectively.
+ alphastd (float): The standard deviation for distribution of alpha.
+ Defaults to 0.1
+ to_rgb (bool): Whether to convert img to rgb.
+
+ Returns:
+ ndarray: The adjusted image.
+ """
+ assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+ f'eigval and eigvec should both be of type np.ndarray, got ' \
+ f'{type(eigval)} and {type(eigvec)} instead.'
+
+ assert eigval.ndim == 1 and eigvec.ndim == 2
+ assert eigvec.shape == (3, eigval.shape[0])
+ n_eigval = eigval.shape[0]
+ assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+ f'got {type(alphastd)} instead.'
+
+ img = img.copy().astype(np.float32)
+ if to_rgb:
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
+
+ alpha = np.random.normal(0, alphastd, n_eigval)
+ alter = eigvec \
+ * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+ * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+ alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+ img_adjusted = img + alter
+ return img_adjusted
+
+
+def lut_transform(img, lut_table):
+ """Transform array by look-up table.
+
+ The function lut_transform fills the output array with values from the
+ look-up table. Indices of the entries are taken from the input array.
+
+ Args:
+ img (ndarray): Image to be transformed.
+ lut_table (ndarray): look-up table of 256 elements; in case of
+ multi-channel input array, the table should either have a single
+ channel (in this case the same table is used for all channels) or
+ the same number of channels as in the input array.
+
+ Returns:
+ ndarray: The transformed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert 0 <= np.min(img) and np.max(img) <= 255
+ assert isinstance(lut_table, np.ndarray)
+ assert lut_table.shape == (256, )
+
+ return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
+
+
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+ """Use CLAHE method to process the image.
+
+ See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+ Graphics Gems, 1994:474-485.` for more information.
+
+ Args:
+ img (ndarray): Image to be processed.
+ clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+ tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+ Input image will be divided into equally sized rectangular tiles.
+ It defines the number of tiles in row and column. Default: (8, 8).
+
+ Returns:
+ ndarray: The processed image.
+ """
+ assert isinstance(img, np.ndarray)
+ assert img.ndim == 2
+ assert isinstance(clip_limit, (float, int))
+ assert is_tuple_of(tile_grid_size, int)
+ assert len(tile_grid_size) == 2
+
+ clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+ return clahe.apply(np.array(img, dtype=np.uint8))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/deprecated.json b/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/deprecated.json
new file mode 100644
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+{
+ "resnet50_caffe": "detectron/resnet50_caffe",
+ "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+ "resnet101_caffe": "detectron/resnet101_caffe",
+ "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
+}
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/mmcls.json b/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/mmcls.json
new file mode 100644
index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/mmcls.json
@@ -0,0 +1,31 @@
+{
+ "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+ "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+ "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+ "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+ "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+ "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+ "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+ "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+ "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth",
+ "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth",
+ "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth",
+ "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth",
+ "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth",
+ "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth",
+ "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth",
+ "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+ "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+ "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+ "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+ "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+ "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+ "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+ "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+ "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+ "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+ "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+ "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth"
+}
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/open_mmlab.json b/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/open_mmlab.json
new file mode 100644
index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+{
+ "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+ "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+ "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+ "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+ "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+ "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+ "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+ "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+ "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+ "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+ "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+ "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+ "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+ "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+ "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+ "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+ "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+ "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+ "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+ "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+ "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+ "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+ "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+ "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+ "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+ "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+ "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
+ "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+ "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+ "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+ "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+ "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+ "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+ "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+ "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+ "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+ "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+ "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+ "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+ "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+ "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+ "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+ "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+ "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+ "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+ "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+ "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+ "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
+}
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/__init__.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assign_score_withk import assign_score_withk
+from .ball_query import ball_query
+from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
+from .box_iou_rotated import box_iou_rotated
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+from .cc_attention import CrissCrossAttention
+from .contour_expand import contour_expand
+from .corner_pool import CornerPool
+from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+ ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+ sigmoid_focal_loss, softmax_focal_loss)
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+from .gather_points import gather_points
+from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+ get_onnxruntime_op_path)
+from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
+from .knn import knn
+from .masked_conv import MaskedConv2d, masked_conv2d
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+ ModulatedDeformConv2dPack,
+ modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .pixel_group import pixel_group
+from .point_sample import (SimpleRoIAlign, point_sample,
+ rel_roi_point_to_rel_img_point)
+from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+ points_in_boxes_part)
+from .points_sampler import PointsSampler
+from .psa_mask import PSAMask
+from .roi_align import RoIAlign, roi_align
+from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+from .roiaware_pool3d import RoIAwarePool3d
+from .roipoint_pool3d import RoIPointPool3d
+from .saconv import SAConv2d
+from .scatter_points import DynamicScatter, dynamic_scatter
+from .sync_bn import SyncBatchNorm
+from .three_interpolate import three_interpolate
+from .three_nn import three_nn
+from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
+from .voxelize import Voxelization, voxelization
+
+__all__ = [
+ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+ 'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+ 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+ 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+ 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+ 'get_compiler_version', 'get_compiling_cuda_version',
+ 'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+ 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+ 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+ 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+ 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+ 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+ 'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+ 'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+ 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+ 'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+ 'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+ 'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+ 'border_align', 'gather_points', 'furthest_point_sample',
+ 'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+ 'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
+ 'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
+ 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/assign_score_withk.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/assign_score_withk.py
new file mode 100644
index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/assign_score_withk.py
@@ -0,0 +1,123 @@
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
+
+
+class AssignScoreWithK(Function):
+ r"""Perform weighted sum to generate output features according to scores.
+ Modified from `PAConv `_.
+
+ This is a memory-efficient CUDA implementation of assign_scores operation,
+ which first transform all point features with weight bank, then assemble
+ neighbor features with ``knn_idx`` and perform weighted sum of ``scores``.
+
+ See the `paper `_ appendix Sec. D for
+ more detailed descriptions.
+
+ Note:
+ This implementation assumes using ``neighbor`` kernel input, which is
+ (point_features - center_features, point_features).
+ See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
+ pointnet2/paconv.py#L128 for more details.
+ """
+
+ @staticmethod
+ def forward(ctx,
+ scores,
+ point_features,
+ center_features,
+ knn_idx,
+ aggregate='sum'):
+ """
+ Args:
+ scores (torch.Tensor): (B, npoint, K, M), predicted scores to
+ aggregate weight matrices in the weight bank.
+ ``npoint`` is the number of sampled centers.
+ ``K`` is the number of queried neighbors.
+ ``M`` is the number of weight matrices in the weight bank.
+ point_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed point features to be aggregated.
+ center_features (torch.Tensor): (B, N, M, out_dim)
+ Pre-computed center features to be aggregated.
+ knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
+ We assume the first idx in each row is the idx of the center.
+ aggregate (str, optional): Aggregation method.
+ Can be 'sum', 'avg' or 'max'. Defaults: 'sum'.
+
+ Returns:
+ torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
+ """
+ agg = {'sum': 0, 'avg': 1, 'max': 2}
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ output = point_features.new_zeros((B, out_dim, npoint, K))
+ ext_module.assign_score_withk_forward(
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ output,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg[aggregate])
+
+ ctx.save_for_backward(output, point_features, center_features, scores,
+ knn_idx)
+ ctx.agg = agg[aggregate]
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ """
+ Args:
+ grad_out (torch.Tensor): (B, out_dim, npoint, K)
+
+ Returns:
+ grad_scores (torch.Tensor): (B, npoint, K, M)
+ grad_point_features (torch.Tensor): (B, N, M, out_dim)
+ grad_center_features (torch.Tensor): (B, N, M, out_dim)
+ """
+ _, point_features, center_features, scores, knn_idx = ctx.saved_tensors
+
+ agg = ctx.agg
+
+ B, N, M, out_dim = point_features.size()
+ _, npoint, K, _ = scores.size()
+
+ grad_point_features = point_features.new_zeros(point_features.shape)
+ grad_center_features = center_features.new_zeros(center_features.shape)
+ grad_scores = scores.new_zeros(scores.shape)
+
+ ext_module.assign_score_withk_backward(
+ grad_out.contiguous(),
+ point_features.contiguous(),
+ center_features.contiguous(),
+ scores.contiguous(),
+ knn_idx.contiguous(),
+ grad_point_features,
+ grad_center_features,
+ grad_scores,
+ B=B,
+ N0=N,
+ N1=npoint,
+ M=M,
+ K=K,
+ O=out_dim,
+ aggregate=agg)
+
+ return grad_scores, grad_point_features, \
+ grad_center_features, None, None
+
+
+assign_score_withk = AssignScoreWithK.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/ball_query.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/ball_query.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
+
+
+class BallQuery(Function):
+ """Find nearby points in spherical space."""
+
+ @staticmethod
+ def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
+ xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ min_radius (float): minimum radius of the balls.
+ max_radius (float): maximum radius of the balls.
+ sample_num (int): maximum number of features in the balls.
+ xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+ center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
+
+ Returns:
+ Tensor: (B, npoint, nsample) tensor with the indices of
+ the features that form the query balls.
+ """
+ assert center_xyz.is_contiguous()
+ assert xyz.is_contiguous()
+ assert min_radius < max_radius
+
+ B, N, _ = xyz.size()
+ npoint = center_xyz.size(1)
+ idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
+
+ ext_module.ball_query_forward(
+ center_xyz,
+ xyz,
+ idx,
+ b=B,
+ n=N,
+ m=npoint,
+ min_radius=min_radius,
+ max_radius=max_radius,
+ nsample=sample_num)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+ return idx
+
+ @staticmethod
+ def backward(ctx, a=None):
+ return None, None, None, None
+
+
+ball_query = BallQuery.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/bbox.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/bbox.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
+ """Calculate overlap between two set of bboxes.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Args:
+ bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format or empty.
+ bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format or empty.
+ If aligned is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union) or "iof" (intersection over
+ foreground).
+
+ Returns:
+ ious(Tensor): shape (m, n) if not aligned, else (m,) ((0, 1) if empty).
+
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> bbox_overlaps(bboxes1, bboxes2)
+ tensor([[0.5000, 0.0000, 0.0000],
+ [0.0000, 0.0000, 1.0000],
+ [0.0000, 0.0000, 0.0000]])
+
+ Example:
+ >>> empty = torch.FloatTensor([])
+ >>> nonempty = torch.FloatTensor([
+ >>> [0, 0, 10, 9],
+ >>> ])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+
+ mode_dict = {'iou': 0, 'iof': 1}
+ assert mode in mode_dict.keys()
+ mode_flag = mode_dict[mode]
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+ assert offset == 1 or offset == 0 # offset is forwarded to the op as-is
+
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ assert rows == cols
+
+ if rows * cols == 0: # one side has no boxes: skip the extension call
+ return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
+
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows, cols))
+ ext_module.bbox_overlaps(
+ bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
+ return ious
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/border_align.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/border_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/border_align.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['border_align_forward', 'border_align_backward'])
+
+
+class BorderAlignFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, boxes, pool_size):
+ return g.op(
+ 'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+
+ @staticmethod
+ def forward(ctx, input, boxes, pool_size):
+ ctx.pool_size = pool_size
+ ctx.input_shape = input.size()
+
+ assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+ assert boxes.size(2) == 4, \
+ 'the last dimension of boxes must be (x1, y1, x2, y2)'
+ assert input.size(1) % 4 == 0, \
+ 'the channel for input feature must be divisible by factor 4'
+
+ # [B, C//4, H*W, 4]
+ output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+ output = input.new_zeros(output_shape)
+ # `argmax_idx` only used for backward
+ argmax_idx = input.new_zeros(output_shape).to(torch.int)
+
+ ext_module.border_align_forward(
+ input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+
+ ctx.save_for_backward(boxes, argmax_idx)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ boxes, argmax_idx = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(ctx.input_shape)
+ # complex head architecture may cause grad_output uncontiguous
+ grad_output = grad_output.contiguous()
+ ext_module.border_align_backward(
+ grad_output,
+ boxes,
+ argmax_idx,
+ grad_input,
+ pool_size=ctx.pool_size)
+ return grad_input, None, None
+
+
+border_align = BorderAlignFunction.apply
+
+
+class BorderAlign(nn.Module):
+ r"""Border align pooling layer.
+
+ Applies border_align over the input feature based on predicted bboxes.
+ The details were described in the paper
+ `BorderDet: Border Feature for Dense Object Detection
+ `_.
+
+ For each border line (e.g. top, left, bottom or right) of each box,
+ border_align does the following:
+ 1. uniformly samples `pool_size`+1 positions on this line, involving \
+ the start and end points.
+ 2. the corresponding features on these points are computed by \
+ bilinear interpolation.
+ 3. max pooling over all the `pool_size`+1 positions are used for \
+ computing pooled feature.
+
+ Args:
+ pool_size (int): number of positions sampled over the boxes' borders
+ (e.g. top, bottom, left, right).
+
+ """
+
+ def __init__(self, pool_size):
+ super(BorderAlign, self).__init__()
+ self.pool_size = pool_size
+
+ def forward(self, input, boxes):
+ """
+ Args:
+ input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+ [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+ right features respectively.
+ boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+ Returns:
+ Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+ (top,left,bottom,right) for the last dimension.
+ """
+ return border_align(input, boxes, self.pool_size)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(pool_size={self.pool_size})'
+ return s
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/box_iou_rotated.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/box_iou_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/box_iou_rotated.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
+
+
+def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
+ """Return intersection-over-union (Jaccard index) of boxes.
+
+ Both sets of boxes are expected to be in
+ (x_center, y_center, width, height, angle) format.
+
+ If ``aligned`` is ``False``, then calculate the ious between each bbox
+ of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+ bboxes1 and bboxes2.
+
+ Arguments:
+ boxes1 (Tensor): rotated bboxes 1. \
+ It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
+ Note that theta is in radian.
+ boxes2 (Tensor): rotated bboxes 2. \
+ It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
+ Note that theta is in radian.
+ mode (str): "iou" (intersection over union) or iof (intersection over
+ foreground).
+
+ Returns:
+ ious(Tensor): shape (N, M) if aligned == False else shape (N,)
+ """
+ assert mode in ['iou', 'iof']
+ mode_dict = {'iou': 0, 'iof': 1}
+ mode_flag = mode_dict[mode]
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if aligned:
+ ious = bboxes1.new_zeros(rows)
+ else:
+ ious = bboxes1.new_zeros((rows * cols))
+ bboxes1 = bboxes1.contiguous()
+ bboxes2 = bboxes2.contiguous()
+ ext_module.box_iou_rotated(
+ bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
+ if not aligned:
+ ious = ious.view(rows, cols)
+ return ious
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/carafe.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/carafe.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward',
+ 'carafe_backward'
+])
+
+
+class CARAFENaiveFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ return g.op(
+ 'mmcv::MMCVCARAFENaive',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ ext_module.carafe_naive_forward(
+ features,
+ masks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ assert grad_output.is_cuda
+
+ features, masks = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ grad_input = torch.zeros_like(features)
+ grad_masks = torch.zeros_like(masks)
+ ext_module.carafe_naive_backward(
+ grad_output.contiguous(),
+ features,
+ masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ return grad_input, grad_masks, None, None, None
+
+
+carafe_naive = CARAFENaiveFunction.apply
+
+
+class CARAFENaive(Module):
+
+ def __init__(self, kernel_size, group_size, scale_factor):
+ super(CARAFENaive, self).__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features, masks):
+ return carafe_naive(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+class CARAFEFunction(Function):
+
+ @staticmethod
+ def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+ return g.op(
+ 'mmcv::MMCVCARAFE',
+ features,
+ masks,
+ kernel_size_i=kernel_size,
+ group_size_i=group_size,
+ scale_factor_f=scale_factor)
+
+ @staticmethod
+ def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+ assert scale_factor >= 1
+ assert masks.size(1) == kernel_size * kernel_size * group_size
+ assert masks.size(-1) == features.size(-1) * scale_factor
+ assert masks.size(-2) == features.size(-2) * scale_factor
+ assert features.size(1) % group_size == 0
+ assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+ ctx.kernel_size = kernel_size
+ ctx.group_size = group_size
+ ctx.scale_factor = scale_factor
+ ctx.feature_size = features.size()
+ ctx.mask_size = masks.size()
+
+ n, c, h, w = features.size()
+ output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+ routput = features.new_zeros(output.size(), requires_grad=False)
+ rfeatures = features.new_zeros(features.size(), requires_grad=False)
+ rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+ ext_module.carafe_forward(
+ features,
+ masks,
+ rfeatures,
+ routput,
+ rmasks,
+ output,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+
+ if features.requires_grad or masks.requires_grad:
+ ctx.save_for_backward(features, masks, rfeatures)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ assert grad_output.is_cuda
+
+ features, masks, rfeatures = ctx.saved_tensors
+ kernel_size = ctx.kernel_size
+ group_size = ctx.group_size
+ scale_factor = ctx.scale_factor
+
+ rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+ rgrad_input = torch.zeros_like(features, requires_grad=False)
+ rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+ grad_input = torch.zeros_like(features, requires_grad=False)
+ grad_masks = torch.zeros_like(masks, requires_grad=False)
+ ext_module.carafe_backward(
+ grad_output.contiguous(),
+ rfeatures,
+ masks,
+ rgrad_output,
+ rgrad_input_hs,
+ rgrad_input,
+ rgrad_masks,
+ grad_input,
+ grad_masks,
+ kernel_size=kernel_size,
+ group_size=group_size,
+ scale_factor=scale_factor)
+ return grad_input, grad_masks, None, None, None
+
+
+carafe = CARAFEFunction.apply
+
+
+class CARAFE(Module):
+ """ CARAFE: Content-Aware ReAssembly of FEatures
+
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ kernel_size (int): reassemble kernel size
+ group_size (int): reassemble group size
+ scale_factor (int): upsample ratio
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self, kernel_size, group_size, scale_factor):
+ super(CARAFE, self).__init__()
+
+ assert isinstance(kernel_size, int) and isinstance(
+ group_size, int) and isinstance(scale_factor, int)
+ self.kernel_size = kernel_size
+ self.group_size = group_size
+ self.scale_factor = scale_factor
+
+ def forward(self, features, masks):
+ return carafe(features, masks, self.kernel_size, self.group_size,
+ self.scale_factor)
+
+
+@UPSAMPLE_LAYERS.register_module(name='carafe')
+class CARAFEPack(nn.Module):
+ """A unified package of CARAFE upsampler that contains: 1) channel
+ compressor 2) content encoder 3) CARAFE op.
+
+ Official implementation of ICCV 2019 paper
+ CARAFE: Content-Aware ReAssembly of FEatures
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ channels (int): input feature channels
+ scale_factor (int): upsample ratio
+ up_kernel (int): kernel size of CARAFE op
+ up_group (int): group size of CARAFE op
+ encoder_kernel (int): kernel size of content encoder
+ encoder_dilation (int): dilation of content encoder
+ compressed_channels (int): output channels of channels compressor
+
+ Returns:
+ upsampled feature map
+ """
+
+ def __init__(self,
+ channels,
+ scale_factor,
+ up_kernel=5,
+ up_group=1,
+ encoder_kernel=3,
+ encoder_dilation=1,
+ compressed_channels=64):
+ super(CARAFEPack, self).__init__()
+ self.channels = channels
+ self.scale_factor = scale_factor
+ self.up_kernel = up_kernel
+ self.up_group = up_group
+ self.encoder_kernel = encoder_kernel
+ self.encoder_dilation = encoder_dilation
+ self.compressed_channels = compressed_channels
+ self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+ 1)
+ self.content_encoder = nn.Conv2d(
+ self.compressed_channels,
+ self.up_kernel * self.up_kernel * self.up_group *
+ self.scale_factor * self.scale_factor,
+ self.encoder_kernel,
+ padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
+ dilation=self.encoder_dilation,
+ groups=1)
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ normal_init(self.content_encoder, std=0.001)
+
+ def kernel_normalizer(self, mask):
+ mask = F.pixel_shuffle(mask, self.scale_factor)
+ n, mask_c, h, w = mask.size()
+ # use float division explicitly,
+ # to void inconsistency while exporting to onnx
+ mask_channel = int(mask_c / float(self.up_kernel**2))
+ mask = mask.view(n, mask_channel, -1, h, w)
+
+ mask = F.softmax(mask, dim=2, dtype=mask.dtype)
+ mask = mask.view(n, mask_c, h, w).contiguous()
+
+ return mask
+
+ def feature_reassemble(self, x, mask):
+ x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
+ return x
+
+ def forward(self, x):
+ compressed_x = self.channel_compressor(x)
+ mask = self.content_encoder(compressed_x)
+ mask = self.kernel_normalizer(mask)
+
+ x = self.feature_reassemble(x, mask)
+ return x
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/cc_attention.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/cc_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..8982f467185b5d839832baa2e51722613a8b87a2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/cc_attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.mmpkg.mmcv.cnn import PLUGIN_LAYERS, Scale
+
+
+def NEG_INF_DIAG(n, device):
+ """Returns a diagonal matrix of size [n, n].
+
+ The diagonal are all "-inf". This is for avoiding calculating the
+ overlapped element in the Criss-Cross twice.
+ """
+ return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
+
+
+@PLUGIN_LAYERS.register_module()
+class CrissCrossAttention(nn.Module):
+ """Criss-Cross Attention Module.
+
+ .. note::
+ Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+ to a pure PyTorch and equivalent implementation. For more
+ details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+ Speed comparison for one forward pass
+
+ - Input size: [2,512,97,97]
+ - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+ +-----------------------+---------------+------------+---------------+
+ | |PyTorch version|CUDA version|Relative speed |
+ +=======================+===============+============+===============+
+ |with torch.no_grad() |0.00554402 s |0.0299619 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+ |no with torch.no_grad()|0.00562803 s |0.0301349 s |5.4x |
+ +-----------------------+---------------+------------+---------------+
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ """
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+ self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
+ self.gamma = Scale(0.)
+ self.in_channels = in_channels
+
+ def forward(self, x):
+ """forward function of Criss-Cross Attention.
+
+ Args:
+ x (Tensor): Input feature. \
+ shape (batch_size, in_channels, height, width)
+ Returns:
+ Tensor: Output of the layer, with shape of \
+ (batch_size, in_channels, height, width)
+ """
+ B, C, H, W = x.size()
+ query = self.query_conv(x)
+ key = self.key_conv(x)
+ value = self.value_conv(x)
+ energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+ H, query.device)
+ energy_H = energy_H.transpose(1, 2)
+ energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+ attn = F.softmax(
+ torch.cat([energy_H, energy_W], dim=-1), dim=-1) # [B,H,W,(H+W)]
+ out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+ out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
+ out = self.gamma(out) + x
+ out = out.contiguous()
+
+ return out
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels})'
+ return s
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/contour_expand.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/contour_expand.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/contour_expand.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
+
+
+def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
+ kernel_num):
+ """Expand kernel contours so that foreground pixels are assigned into
+ instances.
+
+ Arguments:
+ kernel_mask (np.array or Tensor): The instance kernel mask with
+ size hxw.
+ internal_kernel_label (np.array or Tensor): The instance internal
+ kernel label with size hxw.
+ min_kernel_area (int): The minimum kernel area.
+ kernel_num (int): The instance kernel number.
+
+ Returns:
+ label (list): The instance index map with size hxw.
+ """
+ assert isinstance(kernel_mask, (torch.Tensor, np.ndarray))
+ assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray))
+ assert isinstance(min_kernel_area, int)
+ assert isinstance(kernel_num, int)
+
+ if isinstance(kernel_mask, np.ndarray): # numpy input becomes a tensor
+ kernel_mask = torch.from_numpy(kernel_mask)
+ if isinstance(internal_kernel_label, np.ndarray): # same conversion
+ internal_kernel_label = torch.from_numpy(internal_kernel_label)
+
+ if torch.__version__ == 'parrots': # parrots: op takes keyword arguments
+ if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
+ label = [] # empty input: the op is not called
+ else:
+ label = ext_module.contour_expand(
+ kernel_mask,
+ internal_kernel_label,
+ min_kernel_area=min_kernel_area,
+ kernel_num=kernel_num)
+ label = label.tolist() # parrots result is converted to a list
+ else:
+ label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
+ min_kernel_area, kernel_num)
+ return label
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/corner_pool.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/corner_pool.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward',
+ 'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward',
+ 'right_pool_forward', 'right_pool_backward'
+])
+
+# Integer code for each pooling direction; the symbolic() methods below pass
+# it as the ``mode_i`` attribute of the ``mmcv::MMCVCornerPool`` ONNX op.
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
+
+
+class TopPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.top_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.top_pool_backward(input, grad_output)
+ return output
+
+
+class BottomPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.bottom_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.bottom_pool_backward(input, grad_output)
+ return output
+
+
+class LeftPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.left_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.left_pool_backward(input, grad_output)
+ return output
+
+
+class RightPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input):
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
+ return output
+
+ @staticmethod
+ def forward(ctx, input):
+ output = ext_module.right_pool_forward(input)
+ ctx.save_for_backward(input)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_tensors
+ output = ext_module.right_pool_backward(input, grad_output)
+ return output
+
+
+class CornerPool(nn.Module):
+ """Corner Pooling.
+
+ Corner Pooling is a new type of pooling layer that helps a
+ convolutional network better localize corners of bounding boxes.
+
+ Please refer to https://arxiv.org/abs/1808.01244 for more details.
+ Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+ Args:
+ mode(str): Pooling orientation for the pooling layer
+
+ - 'bottom': Bottom Pooling
+ - 'left': Left Pooling
+ - 'right': Right Pooling
+ - 'top': Top Pooling
+
+ Returns:
+ Feature map after pooling.
+ """
+
+ pool_functions = {
+ 'bottom': BottomPoolFunction,
+ 'left': LeftPoolFunction,
+ 'right': RightPoolFunction,
+ 'top': TopPoolFunction,
+ }
+
+ cummax_dim_flip = {
+ 'bottom': (2, False),
+ 'left': (3, True),
+ 'right': (3, False),
+ 'top': (2, True),
+ }
+
+ def __init__(self, mode):
+ super(CornerPool, self).__init__()
+ assert mode in self.pool_functions
+ self.mode = mode
+ self.corner_pool = self.pool_functions[mode]
+
+ def forward(self, x):
+ if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+ if torch.onnx.is_in_onnx_export():
+ assert torch.__version__ >= '1.7.0', \
+ 'When `cummax` serves as an intermediate component whose '\
+ 'outputs is used as inputs for another modules, it\'s '\
+ 'expected that pytorch version must be >= 1.7.0, '\
+ 'otherwise Error appears like: `RuntimeError: tuple '\
+ 'appears in op that does not forward tuples, unsupported '\
+ 'kind: prim::PythonOp`.'
+
+ dim, flip = self.cummax_dim_flip[self.mode]
+ if flip:
+ x = x.flip(dim)
+ pool_tensor, _ = torch.cummax(x, dim=dim)
+ if flip:
+ pool_tensor = pool_tensor.flip(dim)
+ return pool_tensor
+ else:
+ return self.corner_pool.apply(x)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/correlation.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/correlation.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['correlation_forward', 'correlation_backward'])
+
+
+class CorrelationFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input1,
+ input2,
+ kernel_size=1,
+ max_displacement=1,
+ stride=1,
+ padding=1,
+ dilation=1,
+ dilation_patch=1):
+
+ ctx.save_for_backward(input1, input2)
+
+ kH, kW = ctx.kernel_size = _pair(kernel_size)
+ patch_size = max_displacement * 2 + 1
+ ctx.patch_size = patch_size
+ dH, dW = ctx.stride = _pair(stride)
+ padH, padW = ctx.padding = _pair(padding)
+ dilationH, dilationW = ctx.dilation = _pair(dilation)
+ dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
+ dilation_patch)
+
+ output_size = CorrelationFunction._output_size(ctx, input1)
+
+ output = input1.new_zeros(output_size)
+
+ ext_module.correlation_forward(
+ input1,
+ input2,
+ output,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input1, input2 = ctx.saved_tensors
+
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilation_patchH, dilation_patchW = ctx.dilation_patch
+ dH, dW = ctx.stride
+ grad_input1 = torch.zeros_like(input1)
+ grad_input2 = torch.zeros_like(input2)
+
+ ext_module.correlation_backward(
+ grad_output,
+ input1,
+ input2,
+ grad_input1,
+ grad_input2,
+ kH=kH,
+ kW=kW,
+ patchH=patch_size,
+ patchW=patch_size,
+ padH=padH,
+ padW=padW,
+ dilationH=dilationH,
+ dilationW=dilationW,
+ dilation_patchH=dilation_patchH,
+ dilation_patchW=dilation_patchW,
+ dH=dH,
+ dW=dW)
+ return grad_input1, grad_input2, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input1):
+ iH, iW = input1.size(2), input1.size(3)
+ batch_size = input1.size(0)
+ kH, kW = ctx.kernel_size
+ patch_size = ctx.patch_size
+ dH, dW = ctx.stride
+ padH, padW = ctx.padding
+ dilationH, dilationW = ctx.dilation
+ dilatedKH = (kH - 1) * dilationH + 1
+ dilatedKW = (kW - 1) * dilationW + 1
+
+ oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
+ oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
+
+ output_size = (batch_size, patch_size, patch_size, oH, oW)
+ return output_size
+
+
+class Correlation(nn.Module):
+ r"""Correlation operator
+
+ This correlation operator works for optical flow correlation computation.
+
+ There are two batched tensors with shape :math:`(N, C, H, W)`,
+ and the correlation output's shape is :math:`(N, max\_displacement \times
+ 2 + 1, max\_displacement * 2 + 1, H_{out}, W_{out})`
+
+ where
+
+ .. math::
+ H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding -
+ dilation \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ .. math::
+ W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding - dilation
+ \times (kernel\_size - 1) - 1}
+ {stride} + 1\right\rfloor
+
+ the correlation item :math:`(N_i, dy, dx)` is formed by taking the sliding
+ window convolution between input1 and shifted input2,
+
+ .. math::
+ Corr(N_i, dx, dy) =
+ \sum_{c=0}^{C-1}
+ input1(N_i, c) \star
+ \mathcal{S}(input2(N_i, c), dy, dx)
+
+ where :math:`\star` is the valid 2d sliding window convolution operator,
+ and :math:`\mathcal{S}` means shifting the input features (auto-complete
+ zero marginal), and :math:`dx, dy` are shifting distance, :math:`dx, dy \in
+ [-max\_displacement \times dilation\_patch, max\_displacement \times
+ dilation\_patch]`.
+
+ Args:
+ kernel_size (int): The size of sliding window i.e. local neighborhood
+ representing the center points and involved in correlation
+ computation. Defaults to 1.
+ max_displacement (int): The radius for computing correlation volume,
+ but the actual working space can be dilated by dilation_patch.
+ Defaults to 1.
+ stride (int): The stride of the sliding blocks in the input spatial
+ dimensions. Defaults to 1.
+ padding (int): Zero padding added to all four sides of the input1.
+ Defaults to 0.
+ dilation (int): The spacing of local neighborhood that will involved
+ in correlation. Defaults to 1.
+ dilation_patch (int): The spacing between position need to compute
+ correlation. Defaults to 1.
+ """
+
+ def __init__(self,
+ kernel_size: int = 1,
+ max_displacement: int = 1,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ dilation_patch: int = 1) -> None:
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.max_displacement = max_displacement
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.dilation_patch = dilation_patch
+
+ def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+ return CorrelationFunction.apply(input1, input2, self.kernel_size,
+ self.max_displacement, self.stride,
+ self.padding, self.dilation,
+ self.dilation_patch)
+
+ def __repr__(self) -> str:
+ s = self.__class__.__name__
+ s += f'(kernel_size={self.kernel_size}, '
+ s += f'max_displacement={self.max_displacement}, '
+ s += f'stride={self.stride}, '
+ s += f'padding={self.padding}, '
+ s += f'dilation={self.dilation}, '
+ s += f'dilation_patch={self.dilation_patch})'
+ return s
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deform_conv.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de3aae1e7b2258360aef3ad9eb3a351f080f10f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deform_conv.py
@@ -0,0 +1,405 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from annotator.mmpkg.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'deform_conv_forward', 'deform_conv_backward_input',
+ 'deform_conv_backward_parameters'
+])
+
+
+class DeformConv2dFunction(Function):
+
+ @staticmethod
+ def symbolic(g,
+ input,
+ offset,
+ weight,
+ stride,
+ padding,
+ dilation,
+ groups,
+ deform_groups,
+ bias=False,
+ im2col_step=32):
+ return g.op(
+ 'mmcv::MMCVDeformConv2d',
+ input,
+ offset,
+ weight,
+ stride_i=stride,
+ padding_i=padding,
+ dilation_i=dilation,
+ groups_i=groups,
+ deform_groups_i=deform_groups,
+ bias_i=bias,
+ im2col_step_i=im2col_step)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1,
+ bias=False,
+ im2col_step=32):
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ f'Expected 4D tensor as input, got {input.dim()}D tensor \
+ instead.')
+ assert bias is False, 'Only support bias is False.'
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deform_groups = deform_groups
+ ctx.im2col_step = im2col_step
+
+ # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+ # amp won't cast the type of model (float32), but "offset" is cast
+ # to float16 by nn.Conv2d automatically, leading to the type
+ # mismatch with input (when it is float32) or weight.
+ # The flag for whether to use fp16 or amp is the type of "offset",
+ # we cast weight and input to temporarily support fp16 and amp
+ # whatever the pytorch version is.
+ input = input.type_as(offset)
+ weight = weight.type_as(input)
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(
+ DeformConv2dFunction._output_size(ctx, input, weight))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ cur_im2col_step = min(ctx.im2col_step, input.size(0))
+ assert (input.size(0) %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ ext_module.deform_conv_forward(
+ input,
+ weight,
+ offset,
+ output,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ im2col_step=cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ cur_im2col_step = min(ctx.im2col_step, input.size(0))
+ assert (input.size(0) % cur_im2col_step
+ ) == 0, 'batch size must be divisible by im2col_step'
+
+ grad_output = grad_output.contiguous()
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ ext_module.deform_conv_backward_input(
+ input,
+ offset,
+ grad_output,
+ grad_input,
+ grad_offset,
+ weight,
+ ctx.bufs_[0],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ im2col_step=cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ ext_module.deform_conv_backward_parameters(
+ input,
+ offset,
+ grad_output,
+ grad_weight,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ kW=weight.size(3),
+ kH=weight.size(2),
+ dW=ctx.stride[1],
+ dH=ctx.stride[0],
+ padW=ctx.padding[1],
+ padH=ctx.padding[0],
+ dilationW=ctx.dilation[1],
+ dilationH=ctx.dilation[0],
+ group=ctx.groups,
+ deformable_group=ctx.deform_groups,
+ scale=1,
+ im2col_step=cur_im2col_step)
+
+ return grad_input, grad_offset, grad_weight, \
+ None, None, None, None, None, None, None
+
+ @staticmethod
+ def _output_size(ctx, input, weight):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = ctx.padding[d]
+ kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = ctx.stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ 'convolution input is too small (output would be ' +
+ 'x'.join(map(str, output_size)) + ')')
+ return output_size
+
+
+# Functional entry point: calling it runs DeformConv2dFunction through
+# autograd's ``apply``.
+deform_conv2d = DeformConv2dFunction.apply
+
+
+class DeformConv2d(nn.Module):
+ r"""Deformable 2D convolution.
+
+ Applies a deformable 2D convolution over an input signal composed of
+ several input planes. DeformConv2d was described in the paper
+ `Deformable Convolutional Networks
+ `_
+
+ Note:
+ The argument ``im2col_step`` was added in version 1.3.17, which means
+ number of samples processed by the ``im2col_cuda_kernel`` per call.
+ It enables users to define ``batch_size`` and ``im2col_step`` more
+ flexibly and solved `issue mmcv#1440
+ `_.
+
+ Args:
+ in_channels (int): Number of channels in the input image.
+ out_channels (int): Number of channels produced by the convolution.
+ kernel_size(int, tuple): Size of the convolving kernel.
+ stride(int, tuple): Stride of the convolution. Default: 1.
+ padding (int or tuple): Zero-padding added to both sides of the input.
+ Default: 0.
+ dilation (int or tuple): Spacing between kernel elements. Default: 1.
+ groups (int): Number of blocked connections from input.
+ channels to output channels. Default: 1.
+ deform_groups (int): Number of deformable group partitions.
+ bias (bool): If True, adds a learnable bias to the output.
+ Default: False.
+ im2col_step (int): Number of samples processed by im2col_cuda_kernel
+ per call. It will work when ``batch_size`` > ``im2col_step``, but
+ ``batch_size`` must be divisible by ``im2col_step``. Default: 32.
+ `New in version 1.3.17.`
+ """
+
+ @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+ cls_name='DeformConv2d')
+ def __init__(self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, ...]],
+ stride: Union[int, Tuple[int, ...]] = 1,
+ padding: Union[int, Tuple[int, ...]] = 0,
+ dilation: Union[int, Tuple[int, ...]] = 1,
+ groups: int = 1,
+ deform_groups: int = 1,
+ bias: bool = False,
+ im2col_step: int = 32) -> None:
+ super(DeformConv2d, self).__init__()
+
+ assert not bias, \
+ f'bias={bias} is not supported in DeformConv2d.'
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} cannot be divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} cannot be divisible by groups \
+ {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deform_groups = deform_groups
+ self.im2col_step = im2col_step
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ # only weight, no bias
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // self.groups,
+ *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ # switch the initialization of `self.weight` to the standard kaiming
+ # method described in `Delving deep into rectifiers: Surpassing
+ # human-level performance on ImageNet classification` - He, K. et al.
+ # (2015), using a uniform distribution
+ nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
+
+ def forward(self, x: Tensor, offset: Tensor) -> Tensor:
+ """Deformable Convolutional forward function.
+
+ Args:
+ x (Tensor): Input feature, shape (B, C_in, H_in, W_in)
+ offset (Tensor): Offset for deformable convolution, shape
+ (B, deform_groups*kernel_size[0]*kernel_size[1]*2,
+ H_out, W_out), H_out, W_out are equal to the output's.
+
+ An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+ The spatial arrangement is like:
+
+ .. code:: text
+
+ (x0, y0) (x1, y1) (x2, y2)
+ (x3, y3) (x4, y4) (x5, y5)
+ (x6, y6) (x7, y7) (x8, y8)
+
+ Returns:
+ Tensor: Output of the layer.
+ """
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
+ self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
+ offset = offset.contiguous()
+ out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deform_groups,
+ False, self.im2col_step)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
+ pad_w].contiguous()
+ return out
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'(in_channels={self.in_channels},\n'
+ s += f'out_channels={self.out_channels},\n'
+ s += f'kernel_size={self.kernel_size},\n'
+ s += f'stride={self.stride},\n'
+ s += f'padding={self.padding},\n'
+ s += f'dilation={self.dilation},\n'
+ s += f'groups={self.groups},\n'
+ s += f'deform_groups={self.deform_groups},\n'
+ # bias is not supported in DeformConv2d.
+ s += 'bias=False)'
+ return s
+
+
+@CONV_LAYERS.register_module('DCN')
+class DeformConv2dPack(DeformConv2d):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+ The spatial arrangement is like:
+
+ .. code:: text
+
+ (x0, y0) (x1, y1) (x2, y2)
+ (x3, y3) (x4, y4) (x5, y5)
+ (x6, y6) (x7, y7) (x8, y8)
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConv2dPack, self).__init__(*args, **kwargs)
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deform_groups,
+ False, self.im2col_step)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, DeformConvPack loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to '
+ 'version 2.',
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deform_roi_pool.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deform_roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deform_roi_pool.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward'])
+
+
+class DeformRoIPoolFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, rois, offset, output_size, spatial_scale,
+ sampling_ratio, gamma):
+ return g.op(
+ 'mmcv::MMCVDeformRoIPool',
+ input,
+ rois,
+ offset,
+ pooled_height_i=output_size[0],
+ pooled_width_i=output_size[1],
+ spatial_scale_f=spatial_scale,
+ sampling_ratio_f=sampling_ratio,
+ gamma_f=gamma)
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ rois,
+ offset,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ if offset is None:
+ offset = input.new_zeros(0)
+ ctx.output_size = _pair(output_size)
+ ctx.spatial_scale = float(spatial_scale)
+ ctx.sampling_ratio = int(sampling_ratio)
+ ctx.gamma = float(gamma)
+
+ assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+ output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+ ctx.output_size[1])
+ output = input.new_zeros(output_shape)
+
+ ext_module.deform_roi_pool_forward(
+ input,
+ rois,
+ offset,
+ output,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ gamma=ctx.gamma)
+
+ ctx.save_for_backward(input, rois, offset)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, rois, offset = ctx.saved_tensors
+ grad_input = grad_output.new_zeros(input.shape)
+ grad_offset = grad_output.new_zeros(offset.shape)
+
+ ext_module.deform_roi_pool_backward(
+ grad_output,
+ input,
+ rois,
+ offset,
+ grad_input,
+ grad_offset,
+ pooled_height=ctx.output_size[0],
+ pooled_width=ctx.output_size[1],
+ spatial_scale=ctx.spatial_scale,
+ sampling_ratio=ctx.sampling_ratio,
+ gamma=ctx.gamma)
+ if grad_offset.numel() == 0:
+ grad_offset = None
+ return grad_input, None, grad_offset, None, None, None, None
+
+
+# Functional entry point: calling it runs DeformRoIPoolFunction through
+# autograd's ``apply``.
+deform_roi_pool = DeformRoIPoolFunction.apply
+
+
+class DeformRoIPool(nn.Module):
+
+ def __init__(self,
+ output_size,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(DeformRoIPool, self).__init__()
+ self.output_size = _pair(output_size)
+ self.spatial_scale = float(spatial_scale)
+ self.sampling_ratio = int(sampling_ratio)
+ self.gamma = float(gamma)
+
+ def forward(self, input, rois, offset=None):
+ return deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+
+
+class DeformRoIPoolPack(DeformRoIPool):
+
+ def __init__(self,
+ output_size,
+ output_channels,
+ deform_fc_channels=1024,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
+ sampling_ratio, gamma)
+
+ self.output_channels = output_channels
+ self.deform_fc_channels = deform_fc_channels
+
+ self.offset_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 2))
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ def forward(self, input, rois):
+ assert input.size(1) == self.output_channels
+ x = deform_roi_pool(input, rois, None, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ rois_num = rois.size(0)
+ offset = self.offset_fc(x.view(rois_num, -1))
+ offset = offset.view(rois_num, 2, self.output_size[0],
+ self.output_size[1])
+ return deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+
+
+class ModulatedDeformRoIPoolPack(DeformRoIPool):
+
+ def __init__(self,
+ output_size,
+ output_channels,
+ deform_fc_channels=1024,
+ spatial_scale=1.0,
+ sampling_ratio=0,
+ gamma=0.1):
+ super(ModulatedDeformRoIPoolPack,
+ self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
+
+ self.output_channels = output_channels
+ self.deform_fc_channels = deform_fc_channels
+
+ self.offset_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 2))
+ self.offset_fc[-1].weight.data.zero_()
+ self.offset_fc[-1].bias.data.zero_()
+
+ self.mask_fc = nn.Sequential(
+ nn.Linear(
+ self.output_size[0] * self.output_size[1] *
+ self.output_channels, self.deform_fc_channels),
+ nn.ReLU(inplace=True),
+ nn.Linear(self.deform_fc_channels,
+ self.output_size[0] * self.output_size[1] * 1),
+ nn.Sigmoid())
+ self.mask_fc[2].weight.data.zero_()
+ self.mask_fc[2].bias.data.zero_()
+
+ def forward(self, input, rois):
+ assert input.size(1) == self.output_channels
+ x = deform_roi_pool(input, rois, None, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ rois_num = rois.size(0)
+ offset = self.offset_fc(x.view(rois_num, -1))
+ offset = offset.view(rois_num, 2, self.output_size[0],
+ self.output_size[1])
+ mask = self.mask_fc(x.view(rois_num, -1))
+ mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1])
+ d = deform_roi_pool(input, rois, offset, self.output_size,
+ self.spatial_scale, self.sampling_ratio,
+ self.gamma)
+ return d * mask
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deprecated_wrappers.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deprecated_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/deprecated_wrappers.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This file is for backward compatibility.
+# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks.
+import warnings
+
+from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+
+
+class Conv2d_deprecated(Conv2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class ConvTranspose2d_deprecated(ConvTranspose2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
+ 'deprecated in the future. Please import them from "mmcv.cnn" '
+ 'instead')
+
+
+class MaxPool2d_deprecated(MaxPool2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class Linear_deprecated(Linear):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ warnings.warn(
+ 'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
+ ' the future. Please import them from "mmcv.cnn" instead')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/focal_loss.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/focal_loss.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
+ 'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
+])
+
+
+class SigmoidFocalLossFunction(Function):
+    """Autograd function for the sigmoid focal loss kernels in ``_ext``."""
+
+    @staticmethod
+    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+        # ONNX export: emit the custom ``mmcv::MMCVSigmoidFocalLoss`` node.
+        attrs = dict(
+            gamma_f=gamma, alpha_f=alpha, weight_f=weight,
+            reduction_s=reduction)
+        return g.op('mmcv::MMCVSigmoidFocalLoss', input, target, **attrs)
+
+    @staticmethod
+    def forward(ctx, input, target, gamma=2.0, alpha=0.25, weight=None,
+                reduction='mean'):
+        """Return the focal loss of 2-D ``input`` for 1-D long ``target``.
+
+        Args:
+            input (Tensor): 2-D tensor; ``input.size(0)`` matches ``target``.
+            target (Tensor): 1-D tensor of dtype ``torch.long``.
+            gamma (float): Passed to the kernel as ``gamma``.
+            alpha (float): Passed to the kernel as ``alpha``.
+            weight (Tensor, optional): 1-D, ``input.size(1)`` entries.
+            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.
+        """
+        # Check the dtype rather than the legacy ``LongTensor`` classes so
+        # long targets are accepted whatever device they live on.
+        assert target.dtype == torch.long
+        assert input.dim() == 2
+        assert target.dim() == 1
+        assert input.size(0) == target.size(0)
+        if weight is None:
+            weight = input.new_empty(0)  # empty placeholder for the kernel
+        else:
+            assert weight.dim() == 1
+            assert input.size(1) == weight.size(0)
+        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+        assert reduction in ctx.reduction_dict.keys()
+
+        ctx.gamma = float(gamma)
+        ctx.alpha = float(alpha)
+        ctx.reduction = ctx.reduction_dict[reduction]
+
+        output = input.new_zeros(input.size())
+
+        ext_module.sigmoid_focal_loss_forward(
+            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            output = output.sum() / input.size(0)
+        elif ctx.reduction == ctx.reduction_dict['sum']:
+            output = output.sum()
+        ctx.save_for_backward(input, target, weight)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """Return the gradient w.r.t. ``input``; other inputs get ``None``."""
+        input, target, weight = ctx.saved_tensors
+
+        grad_input = input.new_zeros(input.size())
+        ext_module.sigmoid_focal_loss_backward(
+            input, target, weight, grad_input, gamma=ctx.gamma,
+            alpha=ctx.alpha)
+
+        grad_input *= grad_output
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            grad_input /= input.size(0)  # matches the mean in forward
+        return grad_input, None, None, None, None, None
+
+
+sigmoid_focal_loss = SigmoidFocalLossFunction.apply
+
+
+class SigmoidFocalLoss(nn.Module):
+    """Module wrapper that stores the focal-loss settings and forwards them
+    to ``sigmoid_focal_loss``; ``weight`` is registered as a buffer."""
+
+    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+        super().__init__()
+        self.gamma, self.alpha = gamma, alpha
+        self.register_buffer('weight', weight)
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        """Evaluate the loss on ``input``/``target`` with stored settings."""
+        settings = (self.gamma, self.alpha, self.weight, self.reduction)
+        return sigmoid_focal_loss(input, target, *settings)
+
+    def __repr__(self):
+        name = self.__class__.__name__
+        return (f'{name}(gamma={self.gamma}, alpha={self.alpha}, '
+                f'reduction={self.reduction})')
+
+
+class SoftmaxFocalLossFunction(Function):
+    """Autograd function for the softmax focal loss kernels in ``_ext``."""
+
+    @staticmethod
+    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+        # ONNX export: emit the custom ``mmcv::MMCVSoftmaxFocalLoss`` node.
+        attrs = dict(
+            gamma_f=gamma, alpha_f=alpha, weight_f=weight,
+            reduction_s=reduction)
+        return g.op('mmcv::MMCVSoftmaxFocalLoss', input, target, **attrs)
+
+    @staticmethod
+    def forward(ctx, input, target, gamma=2.0, alpha=0.25, weight=None,
+                reduction='mean'):
+        """Return the focal loss of 2-D ``input`` for 1-D long ``target``.
+
+        Args:
+            input (Tensor): 2-D tensor; ``input.size(0)`` matches ``target``.
+            target (Tensor): 1-D tensor of dtype ``torch.long``.
+            gamma (float): Passed to the kernel as ``gamma``.
+            alpha (float): Passed to the kernel as ``alpha``.
+            weight (Tensor, optional): 1-D, ``input.size(1)`` entries.
+            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.
+        """
+        # Check the dtype rather than the legacy ``LongTensor`` classes so
+        # long targets are accepted whatever device they live on.
+        assert target.dtype == torch.long
+        assert input.dim() == 2
+        assert target.dim() == 1
+        assert input.size(0) == target.size(0)
+        if weight is None:
+            weight = input.new_empty(0)  # empty placeholder for the kernel
+        else:
+            assert weight.dim() == 1
+            assert input.size(1) == weight.size(0)
+        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+        assert reduction in ctx.reduction_dict.keys()
+
+        ctx.gamma = float(gamma)
+        ctx.alpha = float(alpha)
+        ctx.reduction = ctx.reduction_dict[reduction]
+
+        # Softmax over dim 1; the row maximum is subtracted before ``exp``.
+        channel_stats, _ = torch.max(input, dim=1)
+        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
+        input_softmax.exp_()
+
+        channel_stats = input_softmax.sum(dim=1)
+        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
+
+        output = input.new_zeros(input.size(0))
+        ext_module.softmax_focal_loss_forward(
+            input_softmax, target, weight, output, gamma=ctx.gamma,
+            alpha=ctx.alpha)
+
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            output = output.sum() / input.size(0)
+        elif ctx.reduction == ctx.reduction_dict['sum']:
+            output = output.sum()
+        # Backward works from the softmax, not from the raw ``input``.
+        ctx.save_for_backward(input_softmax, target, weight)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Return the gradient w.r.t. ``input``; other inputs get ``None``."""
+        input_softmax, target, weight = ctx.saved_tensors
+        # ``buff`` (one entry per row) and ``grad_input`` are zero-filled
+        # buffers handed to the kernel.
+        buff = input_softmax.new_zeros(input_softmax.size(0))
+        grad_input = input_softmax.new_zeros(input_softmax.size())
+
+        ext_module.softmax_focal_loss_backward(
+            input_softmax, target, weight, buff, grad_input,
+            gamma=ctx.gamma, alpha=ctx.alpha)
+
+        grad_input *= grad_output
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            grad_input /= input_softmax.size(0)
+        return grad_input, None, None, None, None, None
+
+
+softmax_focal_loss = SoftmaxFocalLossFunction.apply
+
+
+class SoftmaxFocalLoss(nn.Module):
+    """Module wrapper that stores the focal-loss settings and forwards them
+    to ``softmax_focal_loss``; ``weight`` is registered as a buffer."""
+
+    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+        super().__init__()
+        self.gamma, self.alpha = gamma, alpha
+        self.register_buffer('weight', weight)
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        """Evaluate the loss on ``input``/``target`` with stored settings."""
+        settings = (self.gamma, self.alpha, self.weight, self.reduction)
+        return softmax_focal_loss(input, target, *settings)
+
+    def __repr__(self):
+        name = self.__class__.__name__
+        return (f'{name}(gamma={self.gamma}, alpha={self.alpha}, '
+                f'reduction={self.reduction})')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/furthest_point_sample.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/furthest_point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/furthest_point_sample.py
@@ -0,0 +1,83 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'furthest_point_sampling_forward',
+ 'furthest_point_sampling_with_dist_forward'
+])
+
+
+class FurthestPointSampling(Function):
+    """Uses iterative furthest point sampling to select a set of features whose
+    corresponding points have the furthest distance."""
+
+    @staticmethod
+    def forward(ctx, points_xyz: torch.Tensor,
+                num_points: int) -> torch.Tensor:
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) where N > num_points.
+            num_points (int): Number of points in the sampled set.
+
+        Returns:
+            Tensor: (B, num_points) indices of the sampled points.
+        """
+        assert points_xyz.is_contiguous()
+
+        B, N = points_xyz.size()[:2]
+        # Create both buffers on the device of ``points_xyz`` instead of the
+        # current CUDA device, in line with FurthestPointSamplingWithDist.
+        output = points_xyz.new_zeros([B, num_points], dtype=torch.int32)
+        temp = points_xyz.new_full([B, N], 1e10, dtype=torch.float32)
+
+        # The extension writes the sampled indices into ``output``.
+        ext_module.furthest_point_sampling_forward(
+            points_xyz, temp, output, b=B, n=N, m=num_points)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        # The first parameter receives the autograd context. The output
+        # holds integer indices, so neither forward input gets a
+        # gradient.
+        return None, None
+
+
+class FurthestPointSamplingWithDist(Function):
+    """Furthest point sampling driven by a pairwise distance matrix
+    instead of point coordinates."""
+
+    @staticmethod
+    def forward(ctx, points_dist: torch.Tensor,
+                num_points: int) -> torch.Tensor:
+        """Pick ``num_points`` indices per batch entry.
+
+        Args:
+            points_dist (Tensor): (B, N, N) Distance between each point pair.
+            num_points (int): Number of points in the sampled set.
+
+        Returns:
+            Tensor: (B, num_points) indices of the sampled points.
+        """
+        assert points_dist.is_contiguous()
+        batch, num_total, _ = points_dist.size()
+
+        sampled = points_dist.new_zeros([batch, num_points], dtype=torch.int32)
+        scratch = points_dist.new_zeros([batch, num_total]).fill_(1e10)
+        ext_module.furthest_point_sampling_with_dist_forward(
+            points_dist, scratch, sampled, b=batch, n=num_total, m=num_points)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(sampled)
+        return sampled
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/fused_bias_leakyrelu.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/fused_bias_leakyrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/fused_bias_leakyrelu.py
@@ -0,0 +1,268 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+
+# 3. Limitations
+
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])
+
+
+class FusedBiasLeakyReLUFunctionBackward(Function):
+    """Calculate second order derivative.
+
+    The backward pass of the fused leaky relu operation is wrapped in its
+    own Function so that it can be differentiated a second time.
+    """
+
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        """Compute the gradients of the fused op w.r.t. input and bias.
+
+        Args:
+            grad_output (Tensor): Gradient w.r.t. the output of the op.
+            out (Tensor): Output the op produced in its forward pass.
+            negative_slope (float): Passed to the kernel as ``alpha``.
+            scale (float): Passed to the kernel as ``scale``.
+
+        Returns:
+            tuple: ``(grad_input, grad_bias)``.
+        """
+        ctx.save_for_backward(out)
+        ctx.negative_slope, ctx.scale = negative_slope, scale
+
+        # Kernel call with ``grad=1``; an empty tensor takes the bias slot.
+        no_bias = grad_output.new_empty(0)
+        grad_input = ext_module.fused_bias_leakyrelu(
+            grad_output, no_bias, out, act=3, grad=1, alpha=negative_slope,
+            scale=scale)
+
+        # Sum over every dim except dim 1 to obtain the bias gradient.
+        reduce_dims = [0] + list(range(2, grad_input.ndim))
+        grad_bias = grad_input.sum(reduce_dims).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        """Propagate gradients of ``(grad_input, grad_bias)`` back to
+        ``grad_output``; the other three inputs receive ``None``."""
+
+        out, = ctx.saved_tensors
+
+        # The second order derivative has two parts and the first one is
+        # zero, so only the second part is evaluated; its implementation
+        # is similar to that of the first order derivative.
+        gradgrad_out = ext_module.fused_bias_leakyrelu(
+            gradgrad_input, gradgrad_bias.to(out.dtype), out, act=3, grad=1,
+            alpha=ctx.negative_slope, scale=ctx.scale)
+
+        return gradgrad_out, None, None, None
+
+
+class FusedBiasLeakyReLUFunction(Function):
+    """Autograd function around the ``fused_bias_leakyrelu`` kernel."""
+
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        """Run the kernel on ``input`` and ``bias`` and keep the result
+        for the backward pass."""
+        # ``grad=0`` call; the third positional slot gets an empty tensor.
+        placeholder = input.new_empty(0)
+        out = ext_module.fused_bias_leakyrelu(
+            input, bias, placeholder, act=3, grad=0, alpha=negative_slope,
+            scale=scale)
+
+        ctx.negative_slope, ctx.scale = negative_slope, scale
+        ctx.save_for_backward(out)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Return gradients for ``input`` and ``bias``; ``negative_slope``
+        and ``scale`` get ``None``."""
+        out, = ctx.saved_tensors
+        # Delegated to a separate Function so that the result can itself
+        # be differentiated.
+        grads = FusedBiasLeakyReLUFunctionBackward.apply(
+            grad_output, out, ctx.negative_slope, ctx.scale)
+        grad_input, grad_bias = grads
+        return grad_input, grad_bias, None, None
+
+
+class FusedBiasLeakyReLU(nn.Module):
+    """Fused bias leaky ReLU.
+
+    This function is introduced in the StyleGAN2:
+    http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. In addition, to keep
+    the variance of the feature map or gradients unchanged, they also adopt a
+    scale similarly with Kaiming initialization. However, since the
+    :math:`1+{alpha}^2` is too small, we can just ignore it. Therefore, the
+    final scale is just :math:`\sqrt{2}`. Of course, you may change it with # noqa: W605, E501
+    your own scale.
+
+    TODO: Implement CPU version; non-CUDA inputs use ``bias_leakyrelu_ref``.
+
+    Args:
+        num_channels (int): The channel number of the feature map.
+        negative_slope (float, optional): Same as nn.LeakyReLU.
+            Defaults to 0.2.
+        scale (float, optional): A scalar to adjust the variance of the feature
+            map. Defaults to 2**0.5.
+    """
+
+    def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+        super(FusedBiasLeakyReLU, self).__init__()
+
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
+                                    self.scale)
+
+
+def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+    """Add a per-channel bias, apply leaky ReLU and rescale the result.
+
+    This function is introduced in the StyleGAN2:
+    http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. The scale keeps the
+    variance of the feature map or gradients unchanged, similarly to
+    Kaiming initialization; the default is ``2**0.5`` and you may change it
+    with your own scale.
+
+    CUDA inputs are processed by the fused extension kernel, all other
+    inputs by the pure PyTorch ``bias_leakyrelu_ref``.
+
+    Args:
+        input (torch.Tensor): Input feature map.
+        bias (nn.Parameter): The bias from convolution operation.
+        negative_slope (float, optional): Same as nn.LeakyReLU.
+            Defaults to 0.2.
+        scale (float, optional): A scalar to adjust the variance of the
+            feature map. Defaults to 2**0.5.
+
+    Returns:
+        torch.Tensor: Feature map after non-linear activation.
+    """
+    if input.is_cuda:
+        # The kernel receives the bias cast to the dtype of ``input``.
+        return FusedBiasLeakyReLUFunction.apply(
+            input, bias.to(input.dtype), negative_slope, scale)
+    return bias_leakyrelu_ref(input, bias, negative_slope, scale)
+
+
+def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+    """Pure PyTorch reference: add ``bias`` along dim 1 (if given), apply
+    leaky ReLU, then multiply by ``scale`` unless it equals 1."""
+    if bias is not None:
+        assert bias.ndim == 1
+        assert bias.shape[0] == x.shape[1]
+        # Shape (1, C, 1, ...) so the bias broadcasts over dim 1 of ``x``.
+        broadcast_shape = [1] * x.ndim
+        broadcast_shape[1] = -1
+        x = x + bias.reshape(broadcast_shape)
+    activated = F.leaky_relu(x, negative_slope)
+    return activated if scale == 1 else activated * scale
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/gather_points.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/gather_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/gather_points.py
@@ -0,0 +1,57 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['gather_points_forward', 'gather_points_backward'])
+
+
+class GatherPoints(Function):
+    """Gather points with given index."""
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, N) features to gather.
+            indices (Tensor): (B, M) where M is the number of points.
+
+        Returns:
+            Tensor: (B, C, M) where M is the number of points.
+        """
+        assert features.is_contiguous()
+        assert indices.is_contiguous()
+
+        B, npoint = indices.size()
+        _, C, N = features.size()
+        # Allocate next to ``features`` rather than on the current CUDA
+        # device, which may be a different GPU.
+        output = features.new_zeros((B, C, npoint))
+
+        ext_module.gather_points_forward(
+            features, indices, output, b=B, c=C, n=N, npoints=npoint)
+
+        ctx.for_backwards = (indices, C, N)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(indices)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        """Return the (B, C, N) gradient for ``features`` that the extension
+        computes from ``grad_out``; ``indices`` gets no gradient."""
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+
+        # Zero-initialised on the device (and in the dtype) of ``grad_out``.
+        grad_features = grad_out.new_zeros((B, C, N))
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.gather_points_backward(
+            grad_out_data, idx, grad_features.data, b=B, c=C, n=N,
+            npoints=npoint)
+        return grad_features, None
+
+
+gather_points = GatherPoints.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/group_points.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/group_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/group_points.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+from .ball_query import ball_query
+from .knn import knn
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['group_points_forward', 'group_points_backward'])
+
+
+class QueryAndGroup(nn.Module):
+    """Groups points with a ball query of the given radius, or with kNN.
+
+    Args:
+        max_radius (float): The maximum radius of the balls.
+            If None is given, we will use kNN sampling instead of ball query.
+        sample_num (int): Maximum number of features to gather in the ball.
+        min_radius (float, optional): The minimum radius of the balls.
+            Default: 0.
+        use_xyz (bool, optional): Whether to use xyz.
+            Default: True.
+        return_grouped_xyz (bool, optional): Whether to return grouped xyz.
+            Default: False.
+        normalize_xyz (bool, optional): Whether to normalize xyz.
+            Default: False.
+        uniform_sample (bool, optional): Whether to sample uniformly.
+            Default: False.
+        return_unique_cnt (bool, optional): Whether to return the count of
+            unique samples. Default: False.
+        return_grouped_idx (bool, optional): Whether to return grouped idx.
+            Default: False.
+    """
+
+    def __init__(self,
+                 max_radius,
+                 sample_num,
+                 min_radius=0,
+                 use_xyz=True,
+                 return_grouped_xyz=False,
+                 normalize_xyz=False,
+                 uniform_sample=False,
+                 return_unique_cnt=False,
+                 return_grouped_idx=False):
+        super().__init__()
+        self.max_radius = max_radius
+        self.min_radius = min_radius
+        self.sample_num = sample_num
+        self.use_xyz = use_xyz
+        self.return_grouped_xyz = return_grouped_xyz
+        self.normalize_xyz = normalize_xyz
+        self.uniform_sample = uniform_sample
+        self.return_unique_cnt = return_unique_cnt
+        self.return_grouped_idx = return_grouped_idx
+        if self.return_unique_cnt:
+            assert self.uniform_sample, \
+                'uniform_sample should be True when ' \
+                'returning the count of unique samples'
+        if self.max_radius is None:
+            assert not self.normalize_xyz, \
+                'can not normalize grouped xyz when max_radius is None'
+
+    def forward(self, points_xyz, center_xyz, features=None):
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            center_xyz (Tensor): (B, npoint, 3) coordinates of the centroids.
+            features (Tensor): (B, C, N) Descriptors of the features.
+
+        Returns:
+            Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
+        """
+        # if self.max_radius is None, we will perform kNN instead of ball query
+        # idx is of shape [B, npoint, sample_num]
+        if self.max_radius is None:
+            idx = knn(self.sample_num, points_xyz, center_xyz, False)
+            idx = idx.transpose(1, 2).contiguous()
+        else:
+            idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
+                             points_xyz, center_xyz)
+
+        if self.uniform_sample:
+            unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
+            for i_batch in range(idx.shape[0]):
+                for i_region in range(idx.shape[1]):
+                    unique_ind = torch.unique(idx[i_batch, i_region, :])
+                    num_unique = unique_ind.shape[0]
+                    unique_cnt[i_batch, i_region] = num_unique
+                    sample_ind = torch.randint(
+                        0,
+                        num_unique, (self.sample_num - num_unique, ),
+                        dtype=torch.long)
+                    all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
+                    idx[i_batch, i_region, :] = all_ind  # uniques + re-draws
+
+        xyz_trans = points_xyz.transpose(1, 2).contiguous()
+        # (B, 3, npoint, sample_num)
+        grouped_xyz = grouping_operation(xyz_trans, idx)
+        grouped_xyz_diff = grouped_xyz - \
+            center_xyz.transpose(1, 2).unsqueeze(-1)  # relative offsets
+        if self.normalize_xyz:
+            grouped_xyz_diff /= self.max_radius
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                # (B, C + 3, npoint, sample_num)
+                new_features = torch.cat([grouped_xyz_diff, grouped_features],
+                                         dim=1)
+            else:
+                new_features = grouped_features
+        else:
+            assert (self.use_xyz
+                    ), 'Cannot have not features and not use xyz as a feature!'
+            new_features = grouped_xyz_diff
+
+        ret = [new_features]  # optional extras are appended below
+        if self.return_grouped_xyz:
+            ret.append(grouped_xyz)
+        if self.return_unique_cnt:
+            ret.append(unique_cnt)
+        if self.return_grouped_idx:
+            ret.append(idx)
+        if len(ret) == 1:
+            return ret[0]
+        else:
+            return tuple(ret)
+
+
+class GroupAll(nn.Module):
+    """Put every point into one single group.
+
+    Args:
+        use_xyz (bool): Whether to prepend the xyz coordinates to the
+            features. Only consulted when features are given.
+    """
+
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self,
+                xyz: torch.Tensor,
+                new_xyz: torch.Tensor,
+                features: torch.Tensor = None):
+        """
+        Args:
+            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            new_xyz (Tensor): new xyz coordinates of the features; not
+                read by this method.
+            features (Tensor): (B, C, N) features to group.
+
+        Returns:
+            Tensor: (B, C + 3, 1, N) Grouped feature.
+        """
+        # (B, 3, 1, N)
+        coords = xyz.transpose(1, 2).unsqueeze(2)
+        if features is None:
+            return coords
+
+        # (B, C, 1, N)
+        grouped = features.unsqueeze(2)
+        if not self.use_xyz:
+            return grouped
+        # (B, 3 + C, 1, N)
+        return torch.cat([coords, grouped], dim=1)
+
+
+class GroupingOperation(Function):
+    """Group feature with given index."""
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, N) tensor of features to group.
+            indices (Tensor): (B, npoint, nsample) the indices of
+                features to group with.
+
+        Returns:
+            Tensor: (B, C, npoint, nsample) Grouped features.
+        """
+        features = features.contiguous()
+        indices = indices.contiguous()
+
+        B, nfeatures, nsample = indices.size()
+        _, C, N = features.size()
+        # Allocated on the device of ``features``, not the current device.
+        output = features.new_zeros((B, C, nfeatures, nsample))
+        ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
+                                        indices, output)
+
+        ctx.for_backwards = (indices, N)
+        return output
+
+    @staticmethod
+    def backward(ctx,
+                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
+                of the output from forward.
+
+        Returns:
+            Tensor: (B, C, N) gradient of the features.
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        # Zero-filled on the device of ``grad_out``.
+        grad_features = grad_out.new_zeros((B, C, N))
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.group_points_backward(B, C, N, npoint, nsample,
+                                         grad_out_data, idx,
+                                         grad_features.data)
+        return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/info.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/info.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os
+
+import torch
+
+if torch.__version__ != 'parrots':  # stock PyTorch: ask the ``_ext`` module
+    from ..utils import ext_loader
+    ext_module = ext_loader.load_ext(
+        '_ext', ['get_compiler_version', 'get_compiling_cuda_version'])
+
+    def get_compiler_version():
+        return ext_module.get_compiler_version()
+
+    def get_compiling_cuda_version():
+        return ext_module.get_compiling_cuda_version()
+else:  # parrots build: versions are read from the ``parrots`` package
+    import parrots
+
+    def get_compiler_version():
+        return 'GCC ' + parrots.version.compiler
+
+    def get_compiling_cuda_version():
+        return parrots.version.cuda
+
+
+def get_onnxruntime_op_path():
+    """Return the path of the first ``_ext_ort.*.so`` file located in the
+    parent directory of this module's directory, or ``''`` if none exists."""
+    package_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    candidates = glob.glob(os.path.join(package_dir, '_ext_ort.*.so'))
+
+    if not candidates:
+        return ''
+    # ``glob`` gives no ordering guarantee; the first hit is used as is.
+    return candidates[0]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/iou3d.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/iou3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/iou3d.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
+ 'iou3d_nms_normal_forward'
+])
+
+
+def boxes_iou_bev(boxes_a, boxes_b):
+    """Compute the pairwise bird's-eye-view IoU of two box sets.
+
+    Args:
+        boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
+        boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
+
+    Returns:
+        ans_iou (torch.Tensor): IoU result with shape (M, N).
+    """
+    num_a, num_b = boxes_a.shape[0], boxes_b.shape[0]
+    # Result buffer for the extension, created like ``boxes_a``.
+    ans_iou = boxes_a.new_zeros((num_a, num_b))
+    boxes_a, boxes_b = boxes_a.contiguous(), boxes_b.contiguous()
+    ext_module.iou3d_boxes_iou_bev_forward(boxes_a, boxes_b, ans_iou)
+
+    return ans_iou
+
+
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+    """NMS function GPU implementation (for BEV boxes).
+
+    The overlap of two boxes for IoU calculation is defined as the exact
+    overlapping area of the two boxes. ``pre_max_size`` and
+    ``post_max_size`` optionally cap the boxes before and after NMS.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+            ([x1, y1, x2, y2, ry]).
+        scores (torch.Tensor): Scores of boxes with the shape of [N].
+        thresh (float): Overlap threshold of NMS.
+        pre_max_size (int, optional): Max size of boxes before NMS.
+            Default: None.
+        post_max_size (int, optional): Max size of boxes after NMS.
+            Default: None.
+
+    Returns:
+        torch.Tensor: Indexes after NMS.
+    """
+    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
+    _, order = scores.sort(0, descending=True)
+
+    if pre_max_size is not None:
+        order = order[:pre_max_size]
+    sorted_boxes = boxes[order].contiguous()
+
+    # ``kept`` is filled by the extension; ``num_out`` entries are valid.
+    kept = torch.zeros(sorted_boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_forward(sorted_boxes, kept, thresh)
+    kept = order[kept[:num_out].cuda(sorted_boxes.device)].contiguous()
+    return kept if post_max_size is None else kept[:post_max_size]
+
+
+def nms_normal_bev(boxes, scores, thresh):
+    """NMS for BEV boxes (GPU implementation) that ignores rotation: the
+    overlap used for IoU is the exact overlapping area of two boxes WITH
+    their yaw angle set to 0.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with shape (N, 5).
+        scores (torch.Tensor): Scores of predicted boxes with shape (N).
+        thresh (float): Overlap threshold of NMS.
+
+    Returns:
+        torch.Tensor: Remaining indices with scores in descending order.
+    """
+    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
+    _, order = scores.sort(0, descending=True)
+    sorted_boxes = boxes[order].contiguous()
+
+    # ``kept`` is filled by the extension; ``num_out`` entries are valid.
+    kept = torch.zeros(sorted_boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_normal_forward(sorted_boxes, kept, thresh)
+    return order[kept[:num_out].cuda(sorted_boxes.device)].contiguous()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/knn.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/knn.py
@@ -0,0 +1,77 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
+
+
+class KNN(Function):
+    r"""KNN (CUDA) based on heap data structure.
+    Modified from PAConv (https://github.com/CVMI-Lab/PAConv).
+
+    Find k-nearest points.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                k: int,
+                xyz: torch.Tensor,
+                center_xyz: torch.Tensor = None,
+                transposed: bool = False) -> torch.Tensor:
+        """
+        Args:
+            k (int): number of nearest neighbors.
+            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
+                xyz coordinates of the features.
+            center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
+                False, else (B, 3, npoint). centers of the knn query.
+                Default: None.
+            transposed (bool, optional): whether the input tensors are
+                transposed. Should not explicitly use this keyword when
+                calling knn (=KNN.apply), just add the fourth param.
+                Default: False.
+
+        Returns:
+            Tensor: (B, k, npoint) tensor with the indices of
+                the features that form k-nearest neighbours.
+        """
+        assert 0 < k < 100, 'k should be in range(0, 100)'
+
+        if center_xyz is None:
+            center_xyz = xyz
+
+        if transposed:
+            xyz = xyz.transpose(2, 1).contiguous()
+            center_xyz = center_xyz.transpose(2, 1).contiguous()
+
+        assert xyz.is_contiguous()  # [B, N, 3]
+        assert center_xyz.is_contiguous()  # [B, npoint, 3]
+
+        center_xyz_device = center_xyz.get_device()
+        assert center_xyz_device == xyz.get_device(), \
+            'center_xyz and xyz should be put on the same device'
+        if torch.cuda.current_device() != center_xyz_device:
+            torch.cuda.set_device(center_xyz_device)
+
+        B, npoint, _ = center_xyz.shape
+        N = xyz.shape[1]
+
+        idx = center_xyz.new_zeros((B, npoint, k)).int()
+        dist2 = center_xyz.new_zeros((B, npoint, k)).float()
+        ext_module.knn_forward(
+            xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
+        # idx shape to [B, k, npoint]
+        idx = idx.transpose(2, 1).contiguous()
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        # One entry per forward input (k, xyz, center_xyz, transposed).
+        return None, None, None, None
+
+
+knn = KNN.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/masked_conv.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/masked_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/masked_conv.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])
+
+
+class MaskedConv2dFunction(Function):
+    """Convolution evaluated only where ``mask > 0`` (batch 1, stride 1)."""
+
+    @staticmethod
+    def symbolic(g, features, mask, weight, bias, padding, stride):
+        return g.op(
+            'mmcv::MMCVMaskedConv2d', features, mask, weight, bias,
+            padding_i=padding, stride_i=stride)
+
+    @staticmethod
+    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+        """Run the masked convolution.
+
+        Args:
+            features (Tensor): input of shape (1, C_in, H, W).
+            mask (Tensor): shape (1, H, W); only positions > 0 are computed.
+            weight (Tensor): shape (C_out, C_in, kernel_h, kernel_w).
+            bias (Tensor): shape (C_out, ), added to every computed position.
+            padding (int | tuple[int]): padding as (pad_h, pad_w).
+            stride (int | tuple[int]): must be 1, otherwise ValueError.
+
+        Returns:
+            Tensor: (1, C_out, out_h, out_w), zero-initialised before the
+                masked results are written into it by the extension op.
+        """
+        assert mask.dim() == 3 and mask.size(0) == 1
+        assert features.dim() == 4 and features.size(0) == 1
+        assert features.size()[2:] == mask.size()[1:]
+        pad_h, pad_w = _pair(padding)
+        stride_h, stride_w = _pair(stride)
+        if stride_h != 1 or stride_w != 1:
+            raise ValueError(
+                'Stride could only be 1 in masked_conv2d currently.')
+        out_channel, in_channel, kernel_h, kernel_w = weight.size()
+
+        batch_size = features.size(0)
+        out_h = int(
+            math.floor((features.size(2) + 2 * pad_h -
+                        (kernel_h - 1) - 1) / stride_h + 1))
+        out_w = int(
+            math.floor((features.size(3) + 2 * pad_w -
+                        (kernel_w - 1) - 1) / stride_w + 1))
+        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
+        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
+        if mask_inds.numel() > 0:
+            mask_h_idx = mask_inds[:, 0].contiguous()
+            mask_w_idx = mask_inds[:, 1].contiguous()
+            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
+                                          mask_inds.size(0))
+            ext_module.masked_im2col_forward(
+                features, mask_h_idx, mask_w_idx, data_col,
+                kernel_h=kernel_h, kernel_w=kernel_w,
+                pad_h=pad_h, pad_w=pad_w)
+            # bias + weight @ data_col (positional beta/alpha is deprecated)
+            masked_output = torch.addmm(bias[:, None],
+                                        weight.view(out_channel, -1), data_col)
+            ext_module.masked_col2im_forward(
+                masked_output, mask_h_idx, mask_w_idx, output,
+                height=out_h, width=out_w, channels=out_channel)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        return (None, ) * 6  # one entry per forward input; no gradients
+
+
+masked_conv2d = MaskedConv2dFunction.apply  # functional alias of the op
+
+
+class MaskedConv2d(nn.Conv2d):
+    """A MaskedConv2d which inherits the official Conv2d.
+
+    The masked path (``mask`` given) implements no backward, supports only
+    stride 1 and hands just weight, bias and padding to ``masked_conv2d``.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride,
+                         padding, dilation, groups, bias)
+
+    def forward(self, input, mask=None):
+        if mask is None:  # fallback to the normal Conv2d
+            return super().forward(input)
+        bias = self.bias
+        if bias is None:  # masked_conv2d indexes the bias, so it needs one
+            bias = self.weight.new_zeros(self.out_channels)
+        return masked_conv2d(input, mask, self.weight, bias, self.padding)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/merge_cells.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/merge_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/merge_cells.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..cnn import ConvModule
+
+
+class BaseMergeCell(nn.Module):
+    """The basic class for cells used in NAS-FPN and NAS-FCOS.
+
+    BaseMergeCell takes 2 inputs. After applying convolution
+    on them, they are resized to the target size. Then,
+    they go through binary_op, which depends on the type of cell.
+    If with_out_conv is True, the result of output will go through
+    another convolution layer.
+
+    Args:
+        fused_channels (int): number of input channels in out_conv layer.
+        out_channels (int): number of output channels in out_conv layer.
+        with_out_conv (bool): Whether to use out_conv layer
+        out_conv_cfg (dict): Config dict for convolution layer, which should
+            contain "groups", "kernel_size", "padding", "bias" to build
+            out_conv layer.
+        out_norm_cfg (dict): Config dict for normalization layer in out_conv.
+        out_conv_order (tuple): The order of conv/norm/activation layers in
+            out_conv.
+        with_input1_conv (bool): Whether to use convolution on input1.
+        with_input2_conv (bool): Whether to use convolution on input2.
+        input_conv_cfg (dict): Config dict for building input1_conv layer and
+            input2_conv layer, which is expected to contain the type of
+            convolution.
+            Default: None, which means using conv2d.
+        input_norm_cfg (dict): Config dict for normalization layer in
+            input1_conv and input2_conv layer. Default: None.
+        upsample_mode (str): Interpolation method used to resize the output
+            of input1_conv and input2_conv to target size. Currently, we
+            support ['nearest', 'bilinear']. Default: 'nearest'.
+    """
+
+    def __init__(self,
+                 fused_channels=256,
+                 out_channels=256,
+                 with_out_conv=True,
+                 out_conv_cfg=dict(
+                     groups=1, kernel_size=3, padding=1, bias=True),
+                 out_norm_cfg=None,
+                 out_conv_order=('act', 'conv', 'norm'),
+                 with_input1_conv=False,
+                 with_input2_conv=False,
+                 input_conv_cfg=None,
+                 input_norm_cfg=None,
+                 upsample_mode='nearest'):
+        super(BaseMergeCell, self).__init__()
+        assert upsample_mode in ['nearest', 'bilinear']
+        self.with_out_conv = with_out_conv
+        self.with_input1_conv = with_input1_conv
+        self.with_input2_conv = with_input2_conv
+        self.upsample_mode = upsample_mode
+
+        if self.with_out_conv:
+            self.out_conv = ConvModule(
+                fused_channels,
+                out_channels,
+                **out_conv_cfg,
+                norm_cfg=out_norm_cfg,
+                order=out_conv_order)
+
+        self.input1_conv = self._build_input_conv(
+            out_channels, input_conv_cfg,
+            input_norm_cfg) if with_input1_conv else nn.Sequential()
+        self.input2_conv = self._build_input_conv(
+            out_channels, input_conv_cfg,
+            input_norm_cfg) if with_input2_conv else nn.Sequential()
+
+    def _build_input_conv(self, channel, conv_cfg, norm_cfg):
+        """Build the 3x3 ConvModule applied to an input before resizing."""
+        return ConvModule(
+            channel, channel, 3,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            bias=True)
+
+    @abstractmethod
+    def _binary_op(self, x1, x2):
+        """Merge the two resized inputs; provided by the subclasses."""
+
+    def _resize(self, x, size):
+        """Resize ``x`` to spatial ``size`` by interpolation or max pooling."""
+        size = tuple(size)  # a list would break the comparisons below
+        if x.shape[-2:] == size:
+            return x
+        elif x.shape[-2:] < size:
+            return F.interpolate(x, size=size, mode=self.upsample_mode)
+        assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+        # pool each axis by its own ratio so non-uniform shrinking works too
+        kernel_size = (x.shape[-2] // size[-2], x.shape[-1] // size[-1])
+        return F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+
+    def forward(self, x1, x2, out_size=None):
+        """Conv, resize and merge ``x1`` and ``x2`` (same batch/channels)."""
+        assert x1.shape[:2] == x2.shape[:2]
+        assert out_size is None or len(out_size) == 2
+        if out_size is None:  # resize to larger one
+            out_size = max(x1.size()[2:], x2.size()[2:])
+
+        x1 = self.input1_conv(x1)
+        x2 = self.input2_conv(x2)
+        x1 = self._resize(x1, out_size)
+        x2 = self._resize(x2, out_size)
+
+        x = self._binary_op(x1, x2)
+        if self.with_out_conv:
+            x = self.out_conv(x)
+        return x
+
+
+class SumCell(BaseMergeCell):
+    """Merge cell that adds its two inputs element-wise."""
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+
+    def _binary_op(self, x1, x2):
+        return x1 + x2
+
+
+class ConcatCell(BaseMergeCell):
+    """Merge cell that concatenates its two inputs along dim 1."""
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        # the concatenated tensor carries the channels of both inputs
+        super().__init__(in_channels * 2, out_channels, **kwargs)
+
+    def _binary_op(self, x1, x2):
+        return torch.cat([x1, x2], dim=1)
+
+
+class GlobalPoolingCell(BaseMergeCell):
+    """Merge cell that scales x1 by a pooled, sigmoid-squashed x2."""
+
+    def __init__(self, in_channels=None, out_channels=None, **kwargs):
+        super(GlobalPoolingCell, self).__init__(in_channels, out_channels, **kwargs)
+        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+    def _binary_op(self, x1, x2):
+        return x2 + self.global_pool(x2).sigmoid() * x1
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/modulated_deform_conv.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/modulated_deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..f97278361d5262b1a87432dc5e3eb842b39ceb10
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/modulated_deform_conv.py
@@ -0,0 +1,282 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from annotator.mmpkg.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
+
+
+class ModulatedDeformConv2dFunction(Function):
+    """Autograd Function wrapping the modulated deformable conv ext ops."""
+
+    @staticmethod
+    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
+                 dilation, groups, deform_groups):
+        input_tensors = [input, offset, mask, weight]
+        if bias is not None:
+            input_tensors.append(bias)
+        return g.op(
+            'mmcv::MMCVModulatedDeformConv2d',
+            *input_tensors,
+            stride_i=stride,
+            padding_i=padding,
+            dilation_i=dilation,
+            groups_i=groups,
+            deform_groups_i=deform_groups)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                mask,
+                weight,
+                bias=None,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deform_groups=1):
+        """Apply the forward op; the output size comes from _output_size."""
+        if input is not None and input.dim() != 4:
+            raise ValueError(
+                f'Expected 4D tensor as input, got {input.dim()}D tensor '
+                'instead.')
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deform_groups = deform_groups
+        ctx.with_bias = bias is not None
+        if not ctx.with_bias:
+            bias = input.new_empty(0)  # fake tensor
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
+        input = input.type_as(offset)
+        weight = weight.type_as(input)
+        # bias and mask take part in the same computation, so they are
+        # brought to that dtype as well
+        bias = bias.type_as(input)
+        mask = mask.type_as(input)
+        ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new_empty(
+            ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+        ext_module.modulated_deform_conv_forward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            output,
+            ctx._bufs[1],
+            kernel_h=weight.size(2), kernel_w=weight.size(3),
+            stride_h=ctx.stride[0], stride_w=ctx.stride[1],
+            pad_h=ctx.padding[0], pad_w=ctx.padding[1],
+            dilation_h=ctx.dilation[0], dilation_w=ctx.dilation[1],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            with_bias=ctx.with_bias)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """Return gradients for the five tensor inputs, None for the rest."""
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
+        grad_output = grad_output.contiguous()
+        ext_module.modulated_deform_conv_backward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            ctx._bufs[1],
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_offset,
+            grad_mask,
+            grad_output,
+            kernel_h=weight.size(2), kernel_w=weight.size(3),
+            stride_h=ctx.stride[0], stride_w=ctx.stride[1],
+            pad_h=ctx.padding[0], pad_w=ctx.padding[1],
+            dilation_h=ctx.dilation[0], dilation_w=ctx.dilation[1],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            with_bias=ctx.with_bias)
+        if not ctx.with_bias:
+            grad_bias = None
+
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+                None, None, None, None, None)
+
+    @staticmethod
+    def _output_size(ctx, input, weight):
+        """Return (N, C_out, *spatial) of the conv output or raise ValueError."""
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = ctx.padding[d]
+            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = ctx.stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(map(lambda s: s > 0, output_size)):
+            raise ValueError(
+                'convolution input is too small (output would be ' +
+                'x'.join(map(str, output_size)) + ')')
+        return output_size
+
+
+modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply  # func. alias
+
+
+class ModulatedDeformConv2d(nn.Module):
+    """Modulated deformable conv2d taking explicit offset and mask."""
+
+    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+                            cls_name='ModulatedDeformConv2d')
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deform_groups=1,
+                 bias=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deform_groups = deform_groups
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        weight_shape = (out_channels, in_channels // groups, *self.kernel_size)
+        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.init_weights()
+
+    def init_weights(self):
+        """Draw the weight from U(-b, b), b = 1/sqrt(fan); zero the bias."""
+        fan = self.in_channels
+        for extent in self.kernel_size:
+            fan *= extent
+        bound = 1. / math.sqrt(fan)
+        self.weight.data.uniform_(-bound, bound)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x, offset, mask):
+        return modulated_deform_conv2d(
+            x, offset, mask, self.weight, self.bias, self.stride,
+            self.padding, self.dilation, self.groups, self.deform_groups)
+
+
+@CONV_LAYERS.register_module('DCNv2')
+class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
+    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
+    layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int): Same as nn.Conv2d, while tuple is not supported.
+        padding (int): Same as nn.Conv2d, while tuple is not supported.
+        dilation (int): Same as nn.Conv2d, while tuple is not supported.
+        groups (int): Same as nn.Conv2d.
+        deform_groups (int): Number of deformable group partitions.
+        bias (bool): Whether to create the bias parameter of the main
+            convolution. The offset/mask conv always has a bias.
+    """
+
+    _version = 2  # state-dict version, see _load_from_state_dict
+
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        super(ModulatedDeformConv2dPack, self).init_weights()
+        if hasattr(self, 'conv_offset'):  # not yet set in base __init__
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv_offset(x)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)  # offset halves, then mask
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+                                       self.stride, self.padding,
+                                       self.dilation, self.groups,
+                                       self.deform_groups)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, ModulatedDeformConvPack
+            # loads previous benchmark models.
+            if (prefix + 'conv_offset.weight' not in state_dict
+                    and prefix[:-1] + '_offset.weight' in state_dict):
+                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+                    prefix[:-1] + '_offset.weight')
+            if (prefix + 'conv_offset.bias' not in state_dict
+                    and prefix[:-1] + '_offset.bias' in state_dict):
+                state_dict[prefix +
+                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+                                                                '_offset.bias')
+        # NOTE(review): fires when version > 1, i.e. no remap ran; confirm
+        if version is not None and version > 1:
+            print_log(
+                f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
+                'version 2.',
+                logger='root')
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/multi_scale_deform_attn.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/multi_scale_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe755eaa931565aab77ecc387990328c01447343
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/multi_scale_deform_attn.py
@@ -0,0 +1,358 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, once_differentiable
+
+from annotator.mmpkg.mmcv import deprecated_api_warning
+from annotator.mmpkg.mmcv.cnn import constant_init, xavier_init
+from annotator.mmpkg.mmcv.cnn.bricks.registry import ATTENTION
+from annotator.mmpkg.mmcv.runner import BaseModule
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
+
+
+class MultiScaleDeformableAttnFunction(Function):
+
+    @staticmethod
+    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
+                sampling_locations, attention_weights, im2col_step):
+        """GPU version of multi-scale deformable attention.
+
+        Args:
+            value (Tensor): The value has shape
+                (bs, num_keys, num_heads, embed_dims//num_heads)
+            value_spatial_shapes (Tensor): Spatial shape of
+                each feature map, has shape (num_levels, 2),
+                last dimension 2 represent (h, w)
+            value_level_start_index (Tensor): The start index of each
+                level, has shape (num_levels, ).
+            sampling_locations (Tensor): The location of sampling points,
+                has shape (bs, num_queries, num_heads, num_levels,
+                num_points, 2), the last dimension 2 represent (x, y).
+            attention_weights (Tensor): The weight of sampling points, has
+                shape (bs, num_queries, num_heads, num_levels, num_points).
+            im2col_step (int): The step used in image to column.
+
+        Returns:
+            Tensor: has shape (bs, num_queries, embed_dims)
+        """
+
+        ctx.im2col_step = im2col_step
+        output = ext_module.ms_deform_attn_forward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            im2col_step=ctx.im2col_step)
+        ctx.save_for_backward(value, value_spatial_shapes,
+                              value_level_start_index, sampling_locations,
+                              attention_weights)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """GPU version of backward function.
+
+        Args:
+            grad_output (Tensor): Gradient
+                of output tensor of forward.
+
+        Returns:
+            Tuple[Tensor]: Gradients w.r.t. value, sampling_locations and
+                attention_weights; None for the other forward inputs.
+        """
+        value, value_spatial_shapes, value_level_start_index,\
+            sampling_locations, attention_weights = ctx.saved_tensors
+        grad_value = torch.zeros_like(value)
+        grad_sampling_loc = torch.zeros_like(sampling_locations)
+        grad_attn_weight = torch.zeros_like(attention_weights)
+        # the three zero buffers above are handed to the op as its outputs
+        ext_module.ms_deform_attn_backward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            grad_output.contiguous(),
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight,
+            im2col_step=ctx.im2col_step)
+        # no gradient for spatial shapes, level start index and im2col_step
+        return grad_value, None, None, \
+            grad_sampling_loc, grad_attn_weight, None
+
+
+def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
+                                        sampling_locations, attention_weights):
+    """CPU version of multi-scale deformable attention.
+
+    Args:
+        value (Tensor): The value has shape
+            (bs, num_keys, num_heads, embed_dims//num_heads)
+        value_spatial_shapes (Tensor): Spatial shape of
+            each feature map, has shape (num_levels, 2),
+            last dimension 2 represent (h, w)
+        sampling_locations (Tensor): The location of sampling points,
+            has shape
+            (bs ,num_queries, num_heads, num_levels, num_points, 2),
+            the last dimension 2 represent (x, y).
+        attention_weights (Tensor): The weight of sampling points used
+            when calculate the attention, has shape
+            (bs ,num_queries, num_heads, num_levels, num_points),
+
+    Returns:
+        Tensor: has shape (bs, num_queries, embed_dims)
+    """
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ =\
+        sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
+                             dim=1)
+    sampling_grids = 2 * sampling_locations - 1  # rescale x to 2 * x - 1
+    sampling_value_list = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
+            bs * num_heads, embed_dims, H_, W_)
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :,
+                                          level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampling_value_l_ = F.grid_sample(
+            value_l_,
+            sampling_grid_l_,
+            mode='bilinear',
+            padding_mode='zeros',
+            align_corners=False)
+        sampling_value_list.append(sampling_value_l_)
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs*num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points)
+    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
+              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
+                                              num_queries)
+    return output.transpose(1, 2).contiguous()
+
+
+@ATTENTION.register_module()
+class MultiScaleDeformableAttention(BaseModule):
+    """An attention module used in Deformable-Detr.
+
+    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
+    <https://arxiv.org/pdf/2010.04159.pdf>`_.
+
+    Args:
+        embed_dims (int): The embedding dimension of Attention.
+            Default: 256.
+        num_heads (int): Parallel attention heads. Default: 8.
+        num_levels (int): The number of feature map used in
+            Attention. Default: 4.
+        num_points (int): The number of sampling points for
+            each query in each head. Default: 4.
+        im2col_step (int): The step used in image_to_column.
+            Default: 64.
+        dropout (float): Probability of the Dropout applied to the
+            output before `identity` is added. Default: 0.1.
+        batch_first (bool): If True, query and value have shape
+            (batch, n, embed_dim), otherwise
+            (n, batch, embed_dim). Default to False.
+        norm_cfg (dict): Config dict for normalization layer; it is
+            only stored as an attribute here. Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims=256,
+                 num_heads=8,
+                 num_levels=4,
+                 num_points=4,
+                 im2col_step=64,
+                 dropout=0.1,
+                 batch_first=False,
+                 norm_cfg=None,
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        if embed_dims % num_heads != 0:
+            raise ValueError(f'embed_dims must be divisible by num_heads, '
+                             f'but got {embed_dims} and {num_heads}')
+        dim_per_head = embed_dims // num_heads
+        self.norm_cfg = norm_cfg
+        self.dropout = nn.Dropout(dropout)
+        self.batch_first = batch_first
+
+        # you'd better set dim_per_head to a power of 2
+        # which is more efficient in the CUDA implementation
+        def _is_power_of_2(n):
+            if (not isinstance(n, int)) or (n < 0):
+                raise ValueError(
+                    'invalid input for _is_power_of_2: {} (type: {})'.format(
+                        n, type(n)))
+            return (n & (n - 1) == 0) and n != 0
+
+        if not _is_power_of_2(dim_per_head):
+            warnings.warn(
+                "You'd better set embed_dims in "
+                'MultiScaleDeformAttention to make '
+                'the dimension of each attention head a power of 2 '
+                'which is more efficient in our CUDA implementation.')
+
+        self.im2col_step = im2col_step
+        self.embed_dims = embed_dims
+        self.num_levels = num_levels
+        self.num_heads = num_heads
+        self.num_points = num_points
+        self.sampling_offsets = nn.Linear(
+            embed_dims, num_heads * num_levels * num_points * 2)
+        self.attention_weights = nn.Linear(embed_dims,
+                                           num_heads * num_levels * num_points)
+        self.value_proj = nn.Linear(embed_dims, embed_dims)
+        self.output_proj = nn.Linear(embed_dims, embed_dims)
+        self.init_weights()
+
+    def init_weights(self):
+        """Default initialization for Parameters of Module."""
+        constant_init(self.sampling_offsets, 0.)
+        thetas = torch.arange(
+            self.num_heads,
+            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (grid_init /
+                     grid_init.abs().max(-1, keepdim=True)[0]).view(
+                         self.num_heads, 1, 1,
+                         2).repeat(1, self.num_levels, self.num_points, 1)
+        for i in range(self.num_points):  # scale by the point index + 1
+            grid_init[:, :, i, :] *= i + 1
+
+        self.sampling_offsets.bias.data = grid_init.view(-1)
+        constant_init(self.attention_weights, val=0., bias=0.)
+        xavier_init(self.value_proj, distribution='uniform', bias=0.)
+        xavier_init(self.output_proj, distribution='uniform', bias=0.)
+        self._is_init = True
+
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiScaleDeformableAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_padding_mask=None,
+                reference_points=None,
+                spatial_shapes=None,
+                level_start_index=None,
+                **kwargs):
+        """Forward Function of MultiScaleDeformAttention.
+
+        Args:
+            query (Tensor): Query of Transformer with shape
+                (num_query, bs, embed_dims).
+            key (Tensor): The key tensor with shape
+                `(num_key, bs, embed_dims)`. Not used by this module.
+            value (Tensor): The value tensor with shape
+                `(num_key, bs, embed_dims)`. If None, `query` is used.
+            identity (Tensor): The tensor used for addition, with the
+                same shape as `query`. Default None. If None,
+                `query` will be used.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            kwargs: Extra keyword arguments such as `key_pos` are
+                accepted but not used.
+            reference_points (Tensor): The normalized reference
+                points with shape (bs, num_query, num_levels, 2),
+                all elements are in the range [0, 1], top-left (0, 0),
+                bottom-right (1, 1), including padding area,
+                or (N, Length_{query}, num_levels, 4), where the
+                additional two dimensions (w, h) form
+                reference boxes.
+            key_padding_mask (Tensor): Mask for `value`, with
+                shape [bs, num_key]; masked value entries are zeroed.
+            spatial_shapes (Tensor): Spatial shape of features in
+                different levels. With shape (num_levels, 2),
+                last dimension represents (h, w).
+            level_start_index (Tensor): The start index of each level.
+                A tensor has shape ``(num_levels, )`` and can be represented
+                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+        Returns:
+            Tensor: forwarded results with shape [num_query, bs, embed_dims].
+        """
+
+        if value is None:
+            value = query
+
+        if identity is None:
+            identity = query
+        if query_pos is not None:
+            query = query + query_pos
+        if not self.batch_first:
+            # change to (bs, num_query ,embed_dims)
+            query = query.permute(1, 0, 2)
+            value = value.permute(1, 0, 2)
+
+        bs, num_query, _ = query.shape
+        bs, num_value, _ = value.shape
+        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+        value = self.value_proj(value)
+        if key_padding_mask is not None:
+            value = value.masked_fill(key_padding_mask[..., None], 0.0)
+        value = value.view(bs, num_value, self.num_heads, -1)
+        sampling_offsets = self.sampling_offsets(query).view(
+            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+        attention_weights = self.attention_weights(query).view(
+            bs, num_query, self.num_heads, self.num_levels * self.num_points)
+        attention_weights = attention_weights.softmax(-1)
+
+        attention_weights = attention_weights.view(bs, num_query,
+                                                   self.num_heads,
+                                                   self.num_levels,
+                                                   self.num_points)
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = torch.stack(
+                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+            sampling_locations = reference_points[:, :, None, :, None, :] \
+                + sampling_offsets \
+                / offset_normalizer[None, None, None, :, None, :]
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = reference_points[:, :, None, :, None, :2] \
+                + sampling_offsets / self.num_points \
+                * reference_points[:, :, None, :, None, 2:] \
+                * 0.5
+        else:
+            raise ValueError(
+                f'Last dim of reference_points must be'
+                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
+        if torch.cuda.is_available() and value.is_cuda:
+            output = MultiScaleDeformableAttnFunction.apply(
+                value, spatial_shapes, level_start_index, sampling_locations,
+                attention_weights, self.im2col_step)
+        else:
+            output = multi_scale_deformable_attn_pytorch(
+                value, spatial_shapes, sampling_locations, attention_weights)
+
+        output = self.output_proj(output)
+
+        if not self.batch_first:
+            # (num_query, bs ,embed_dims)
+            output = output.permute(1, 0, 2)
+
+        return self.dropout(output) + identity
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/nms.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..908ac66645eef29fb55fce82497eb9f6af1a2667
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/nms.py
@@ -0,0 +1,417 @@
+import os
+
+import numpy as np
+import torch
+
+from annotator.mmpkg.mmcv.utils import deprecated_api_warning
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+
+
+# This function is modified from: https://github.com/pytorch/vision/
+class NMSop(torch.autograd.Function):
+    """NMS autograd op: extension kernel in forward, ONNX nodes in symbolic."""
+    @staticmethod
+    def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+                max_num):
+        is_filtering_by_score = score_threshold > 0
+        if is_filtering_by_score:
+            valid_mask = scores > score_threshold
+            bboxes, scores = bboxes[valid_mask], scores[valid_mask]
+            valid_inds = torch.nonzero(
+                valid_mask, as_tuple=False).squeeze(dim=1)
+        # Run the extension kernel on the (possibly score-filtered) boxes.
+        inds = ext_module.nms(
+            bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+        # A non-positive max_num leaves the kept indices untruncated.
+        if max_num > 0:
+            inds = inds[:max_num]
+        if is_filtering_by_score:
+            inds = valid_inds[inds]  # back to indices into the unfiltered input
+        return inds
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+                 max_num):
+        from ..onnx import is_custom_op_loaded
+        has_custom_op = is_custom_op_loaded()
+        # TensorRT nms plugin is aligned with original nms in ONNXRuntime
+        is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+        if has_custom_op and (not is_trt_backend):
+            return g.op(
+                'mmcv::NonMaxSuppression',
+                bboxes,
+                scores,
+                iou_threshold_f=float(iou_threshold),
+                offset_i=int(offset))
+        else:
+            from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+            from ..onnx.onnx_utils.symbolic_helper import _size_helper
+            # Prepend one axis to the boxes and two axes to the scores.
+            boxes = unsqueeze(g, bboxes, 0)
+            scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
+            # Without a positive max_num, use the size of bboxes along dim 0.
+            if max_num > 0:
+                max_num = g.op(
+                    'Constant',
+                    value_t=torch.tensor(max_num, dtype=torch.long))
+            else:
+                dim = g.op('Constant', value_t=torch.tensor(0))
+                max_num = _size_helper(g, bboxes, dim)
+            max_output_per_class = max_num
+            iou_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([iou_threshold], dtype=torch.float))
+            score_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([score_threshold], dtype=torch.float))
+            nms_out = g.op('NonMaxSuppression', boxes, scores,
+                           max_output_per_class, iou_threshold,
+                           score_threshold)
+            return squeeze(
+                g,
+                select(
+                    g, nms_out, 1,
+                    g.op(
+                        'Constant',
+                        value_t=torch.tensor([2], dtype=torch.long))), 1)
+
+
+class SoftNMSop(torch.autograd.Function):
+    """Soft-NMS autograd op; forward hands CPU copies of its inputs to the kernel."""
+    @staticmethod
+    def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
+                offset):
+        # (N, 5) CPU buffer passed to the kernel; presumably filled in place.
+        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+        inds = ext_module.softnms(
+            boxes.cpu(),
+            scores.cpu(),
+            dets.cpu(),
+            iou_threshold=float(iou_threshold),
+            sigma=float(sigma),
+            min_score=float(min_score),
+            method=int(method),
+            offset=int(offset))
+        return dets, inds
+    # ONNX export emits the custom mmcv::SoftNonMaxSuppression node (2 outputs).
+    @staticmethod
+    def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
+                 offset):
+        from packaging import version
+        assert version.parse(torch.__version__) >= version.parse('1.7.0')
+        nms_out = g.op(
+            'mmcv::SoftNonMaxSuppression',
+            boxes,
+            scores,
+            iou_threshold_f=float(iou_threshold),
+            sigma_f=float(sigma),
+            min_score_f=float(min_score),
+            method_i=int(method),
+            offset_i=int(offset),
+            outputs=2)
+        return nms_out
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
+    """Dispatch to either CPU or GPU NMS implementations.
+
+    The input can be either torch tensor or numpy array. GPU NMS will be used
+    if the input is gpu tensor, otherwise CPU NMS
+    will be used. The returned type will always be the same as inputs.
+
+    Arguments:
+        boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+        scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+        iou_threshold (float): IoU threshold for NMS.
+        offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+        score_threshold (float): score threshold for NMS.
+        max_num (int): maximum number of boxes after NMS.
+
+    Returns:
+        tuple: kept dets (boxes and scores) and indices, which are always \
+            the same data type as the input.
+
+    Example:
+        >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9],
+        >>>                   [49.3, 32.9, 51.0, 35.3],
+        >>>                   [49.2, 31.8, 51.0, 35.4],
+        >>>                   [35.1, 11.5, 39.1, 15.7],
+        >>>                   [35.6, 11.8, 39.3, 14.2],
+        >>>                   [35.3, 11.5, 39.9, 14.5],
+        >>>                   [35.2, 11.7, 39.7, 15.7]], dtype=np.float32)
+        >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\
+               dtype=np.float32)
+        >>> iou_threshold = 0.6
+        >>> dets, inds = nms(boxes, scores, iou_threshold)
+        >>> assert len(inds) == len(dets) == 3
+    """
+    assert isinstance(boxes, (torch.Tensor, np.ndarray))
+    assert isinstance(scores, (torch.Tensor, np.ndarray))
+    is_numpy = False
+    if isinstance(boxes, np.ndarray):
+        is_numpy = True
+        boxes = torch.from_numpy(boxes)
+    if isinstance(scores, np.ndarray):
+        scores = torch.from_numpy(scores)
+    assert boxes.size(1) == 4
+    assert boxes.size(0) == scores.size(0)
+    assert offset in (0, 1)
+    # NOTE: the parrots branch below ignores score_threshold and max_num.
+    if torch.__version__ == 'parrots':
+        indata_list = [boxes, scores]
+        indata_dict = {
+            'iou_threshold': float(iou_threshold),
+            'offset': int(offset)
+        }
+        inds = ext_module.nms(*indata_list, **indata_dict)
+    else:
+        inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+                           score_threshold, max_num)
+    dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
+    if is_numpy:
+        dets = dets.cpu().numpy()
+        inds = inds.cpu().numpy()
+    return dets, inds
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def soft_nms(boxes,
+             scores,
+             iou_threshold=0.3,
+             sigma=0.5,
+             min_score=1e-3,
+             method='linear',
+             offset=0):
+    """Dispatch to only CPU Soft NMS implementations.
+
+    The input can be either a torch tensor or numpy array.
+    The returned type will always be the same as inputs.
+
+    Arguments:
+        boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+        scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+        iou_threshold (float): IoU threshold for NMS.
+        sigma (float): hyperparameter for gaussian method
+        min_score (float): score filter threshold
+        method (str): one of 'naive', 'linear' or 'gaussian'
+        offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+
+    Returns:
+        tuple: kept dets (boxes and scores) and indices, which are always \
+            the same data type as the input.
+
+    Example:
+        >>> boxes = np.array([[4., 3., 5., 3.],
+        >>>                   [4., 3., 5., 4.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.]], dtype=np.float32)
+        >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32)
+        >>> iou_threshold = 0.6
+        >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5)
+        >>> assert len(inds) == len(dets) == 5
+    """
+
+    assert isinstance(boxes, (torch.Tensor, np.ndarray))
+    assert isinstance(scores, (torch.Tensor, np.ndarray))
+    is_numpy = False
+    if isinstance(boxes, np.ndarray):
+        is_numpy = True
+        boxes = torch.from_numpy(boxes)
+    if isinstance(scores, np.ndarray):
+        scores = torch.from_numpy(scores)
+    assert boxes.size(1) == 4
+    assert boxes.size(0) == scores.size(0)
+    assert offset in (0, 1)
+    method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2}
+    assert method in method_dict.keys()
+    # Both branches hand CPU copies of the tensors to the soft-NMS kernel.
+    if torch.__version__ == 'parrots':
+        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+        indata_list = [boxes.cpu(), scores.cpu(), dets.cpu()]
+        indata_dict = {
+            'iou_threshold': float(iou_threshold),
+            'sigma': float(sigma),
+            'min_score': min_score,
+            'method': method_dict[method],
+            'offset': int(offset)
+        }
+        inds = ext_module.softnms(*indata_list, **indata_dict)
+    else:
+        dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
+                                     float(iou_threshold), float(sigma),
+                                     float(min_score), method_dict[method],
+                                     int(offset))
+    # dets was allocated with one row per input box; keep len(inds) rows.
+    dets = dets[:inds.size(0)]
+
+    if is_numpy:
+        dets = dets.cpu().numpy()
+        inds = inds.cpu().numpy()
+        return dets, inds
+    else:
+        return dets.to(device=boxes.device), inds.to(device=boxes.device)
+
+
+def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
+    """Performs non-maximum suppression in a batched fashion.
+
+    Modified from https://github.com/pytorch/vision/blob
+    /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
+    In order to perform NMS independently per class, we add an offset to all
+    the boxes. The offset is dependent only on the class idx, and is large
+    enough so that boxes from different classes do not overlap.
+
+    Arguments:
+        boxes (torch.Tensor): boxes in shape (N, 4).
+        scores (torch.Tensor): scores in shape (N, ).
+        idxs (torch.Tensor): each index value correspond to a bbox cluster,
+            and NMS will not be applied between elements of different idxs,
+            shape (N, ).
+        nms_cfg (dict): specify nms type and other parameters like iou_thr.
+            Possible keys includes the following.
+
+            - iou_thr (float): IoU threshold used for NMS.
+            - split_thr (float): threshold number of boxes. In some cases the
+                number of boxes is large (e.g., 200k). To avoid OOM during
+                training, the users could set `split_thr` to a small value.
+                If the number of boxes is greater than the threshold, it will
+                perform NMS on each group of boxes separately and sequentially.
+                Defaults to 10000.
+        class_agnostic (bool): if true, nms is class agnostic,
+            i.e. IoU thresholding happens over all boxes,
+            regardless of the predicted class.
+
+    Returns:
+        tuple: kept dets and indices.
+    """
+    nms_cfg_ = nms_cfg.copy()
+    class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
+    if class_agnostic:
+        boxes_for_nms = boxes
+    else:
+        max_coordinate = boxes.max()
+        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+        boxes_for_nms = boxes + offsets[:, None]
+
+    nms_type = nms_cfg_.pop('type', 'nms')
+    nms_op = globals()[nms_type]  # look up by name; no eval() on config text
+
+    split_thr = nms_cfg_.pop('split_thr', 10000)
+    # Won't split to multiple nms nodes when exporting to onnx
+    if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
+        dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
+        boxes = boxes[keep]
+        # -1 indexing works abnormal in TensorRT
+        # This assumes `dets` has 5 dimensions where
+        # the last dimension is score.
+        # TODO: more elegant way to handle the dimension issue.
+        # Some type of nms would reweight the score, such as SoftNMS
+        scores = dets[:, 4]
+    else:
+        max_num = nms_cfg_.pop('max_num', -1)
+        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+        # Some type of nms would reweight the score, such as SoftNMS
+        scores_after_nms = scores.new_zeros(scores.size())
+        for cluster_id in torch.unique(idxs):
+            mask = (idxs == cluster_id).nonzero(as_tuple=False).view(-1)
+            dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_)
+            total_mask[mask[keep]] = True
+            scores_after_nms[mask[keep]] = dets[:, -1]
+        keep = total_mask.nonzero(as_tuple=False).view(-1)
+
+        scores, inds = scores_after_nms[keep].sort(descending=True)
+        keep = keep[inds]
+        boxes = boxes[keep]
+
+        if max_num > 0:
+            keep = keep[:max_num]
+            boxes = boxes[:max_num]
+            scores = scores[:max_num]
+
+    return torch.cat([boxes, scores[:, None]], -1), keep
+
+
+def nms_match(dets, iou_threshold):
+    """Matched dets into different groups by NMS.
+
+    NMS match is similar to NMS but when a bbox is suppressed, nms match will
+    record the index of the suppressed bbox and form a group with the index
+    of the kept bbox. In each group, indices are sorted in score order.
+
+    Arguments:
+        dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5).
+        iou_threshold (float): IoU thresh for NMS.
+
+    Returns:
+        List[torch.Tensor | np.ndarray]: The outer list corresponds different
+            matched group, the inner Tensor corresponds the indices for a group
+            in score order.
+    """
+    if dets.shape[0] == 0:
+        matched = []
+    else:
+        assert dets.shape[-1] == 5, 'inputs dets.shape should be (N, 5), ' \
+            f'but get {dets.shape}'
+        if isinstance(dets, torch.Tensor):
+            dets_t = dets.detach().cpu()
+        else:
+            dets_t = torch.from_numpy(dets)
+        indata_list = [dets_t]
+        indata_dict = {'iou_threshold': float(iou_threshold)}
+        matched = ext_module.nms_match(*indata_list, **indata_dict)
+        if torch.__version__ == 'parrots':
+            matched = matched.tolist()
+
+    if isinstance(dets, torch.Tensor):
+        return [dets.new_tensor(m, dtype=torch.long) for m in matched]
+    else:
+        # `np.int` was only an alias of builtin int and is gone in NumPy 1.24.
+        return [np.array(m, dtype=int) for m in matched]
+
+
+def nms_rotated(dets, scores, iou_threshold, labels=None):
+    """Performs non-maximum suppression (NMS) on the rotated boxes according to
+    their intersection-over-union (IoU).
+
+    Rotated NMS iteratively removes lower scoring rotated boxes which have an
+    IoU greater than iou_threshold with another (higher scoring) rotated box.
+
+    Args:
+        dets (Tensor): Rotated boxes in shape (N, 5). They are expected to \
+            be in (x_ctr, y_ctr, width, height, angle_radian) format.
+        scores (Tensor): scores in shape (N, ).
+        iou_threshold (float): IoU thresh for NMS.
+        labels (Tensor): boxes' label in shape (N,).
+
+    Returns:
+        tuple: kept dets (boxes and scores) and indices; for empty input \
+            the unchanged dets and None are returned.
+    """
+    if dets.shape[0] == 0:
+        return dets, None
+    multi_label = labels is not None
+    if multi_label:
+        dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1)
+    else:
+        dets_wl = dets
+    _, order = scores.sort(0, descending=True)
+    dets_sorted = dets_wl.index_select(0, order)
+
+    if torch.__version__ == 'parrots':
+        keep_inds = ext_module.nms_rotated(
+            dets_wl,
+            scores,
+            order,
+            dets_sorted,
+            iou_threshold=iou_threshold,
+            multi_label=multi_label)
+    else:
+        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
+                                           iou_threshold, multi_label)
+    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
+                     dim=1)
+    return dets, keep_inds
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/pixel_group.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/pixel_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/pixel_group.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['pixel_group'])
+
+
+def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
+                kernel_region_num, distance_threshold):
+    """Group pixels into text instances, which is widely used text detection
+    methods.
+
+    Arguments:
+        score (np.array or Tensor): The foreground score with size hxw.
+        mask (np.array or Tensor): The foreground mask with size hxw.
+        embedding (np.array or Tensor): The embedding with size hxwxc to
+            distinguish instances.
+        kernel_label (np.array or Tensor): The instance kernel index with
+            size hxw.
+        kernel_contour (np.array or Tensor): The kernel contour with size hxw.
+        kernel_region_num (int): The instance kernel region number.
+        distance_threshold (float): The embedding distance threshold between
+            kernel and pixel in one instance.
+
+    Returns:
+        pixel_assignment (List[List[float]]): The instance coordinate list.
+            Each element consists of averaged confidence, pixel number, and
+            coordinates (x_i, y_i for all pixels) in order.
+    """
+    assert isinstance(score, (torch.Tensor, np.ndarray))
+    assert isinstance(mask, (torch.Tensor, np.ndarray))
+    assert isinstance(embedding, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_label, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_contour, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_region_num, int)
+    assert isinstance(distance_threshold, float)
+
+    if isinstance(score, np.ndarray):
+        score = torch.from_numpy(score)
+    if isinstance(mask, np.ndarray):
+        mask = torch.from_numpy(mask)
+    if isinstance(embedding, np.ndarray):
+        embedding = torch.from_numpy(embedding)
+    if isinstance(kernel_label, np.ndarray):
+        kernel_label = torch.from_numpy(kernel_label)
+    if isinstance(kernel_contour, np.ndarray):
+        kernel_contour = torch.from_numpy(kernel_contour)
+    # parrots returns one flat list: per-region lengths first, then the data.
+    if torch.__version__ == 'parrots':
+        label = ext_module.pixel_group(
+            score,
+            mask,
+            embedding,
+            kernel_label,
+            kernel_contour,
+            kernel_region_num=kernel_region_num,
+            distance_threshold=distance_threshold)
+        label = label.tolist()
+        label = label[0]
+        list_index = kernel_region_num
+        pixel_assignment = []
+        for x in range(kernel_region_num):
+            pixel_assignment.append(
+                np.array(
+                    label[list_index:list_index + int(label[x])],
+                    dtype=float))  # np.float alias was removed in NumPy 1.24
+            list_index = list_index + int(label[x])
+    else:
+        pixel_assignment = ext_module.pixel_group(score, mask, embedding,
+                                                  kernel_label, kernel_contour,
+                                                  kernel_region_num,
+                                                  distance_threshold)
+    return pixel_assignment
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/point_sample.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..08b1617805fa84e1c8afc61f3263b4b86bd2a136
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/point_sample.py
@@ -0,0 +1,336 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+
+from os import path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+
+
+def bilinear_grid_sample(im, grid, align_corners=False):
+    """Given an input and a flow-field grid, computes the output using input
+    values and pixel locations from grid. Only the bilinear interpolation
+    method is supported for sampling the input pixels.
+
+    Args:
+        im (torch.Tensor): Input feature map, shape (N, C, H, W)
+        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input's
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input's corner pixels,
+            making the sampling more resolution agnostic.
+    Returns:
+        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
+    """
+    n, c, h, w = im.shape
+    gn, gh, gw, _ = grid.shape
+    assert n == gn
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+    # Map normalized [-1, 1] coordinates to pixel coordinates.
+    if align_corners:
+        x = ((x + 1) / 2) * (w - 1)
+        y = ((y + 1) / 2) * (h - 1)
+    else:
+        x = ((x + 1) * w - 1) / 2
+        y = ((y + 1) * h - 1) / 2
+
+    x = x.view(n, -1)
+    y = y.view(n, -1)
+    # Integer corners of the pixel cell surrounding each sample point.
+    x0 = torch.floor(x).long()
+    y0 = torch.floor(y).long()
+    x1 = x0 + 1
+    y1 = y0 + 1
+    # Weights of the four corners: a=(x0,y0) b=(x0,y1) c=(x1,y0) d=(x1,y1).
+    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+    wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+    # Apply default for grid_sample function zero padding
+    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+    padded_h = h + 2
+    padded_w = w + 2
+    # save points positions after padding
+    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+    # Clip coordinates to padded image size
+    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+    im_padded = im_padded.view(n, c, -1)
+    # Flat (row-major) indices into the padded image, repeated per channel.
+    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+    Ia = torch.gather(im_padded, 2, x0_y0)
+    Ib = torch.gather(im_padded, 2, x0_y1)
+    Ic = torch.gather(im_padded, 2, x1_y0)
+    Id = torch.gather(im_padded, 2, x1_y1)
+
+    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
+
+def is_in_onnx_export_without_custom_ops():
+    """True during ONNX export when get_onnxruntime_op_path() does not exist."""
+    from annotator.mmpkg.mmcv.ops import get_onnxruntime_op_path
+    custom_op_lib = get_onnxruntime_op_path()
+    return torch.onnx.is_in_onnx_export() and not osp.exists(custom_op_lib)
+
+
+def normalize(grid):
+    """Map grid coordinates from the range [-1, 1] onto [0, 1]
+    Args:
+        grid (Tensor): Grid to normalize, values in [-1, 1].
+    Returns:
+        Tensor: The shifted and halved grid, values in [0, 1].
+    """
+    shifted = grid + 1.0
+    return shifted / 2.0
+
+
+def denormalize(grid):
+    """Map grid coordinates from the range [0, 1] back onto [-1, 1]
+    Args:
+        grid (Tensor): Grid to denormalize, values in [0, 1].
+    Returns:
+        Tensor: The doubled and shifted grid, values in [-1, 1].
+    """
+    doubled = grid * 2.0
+    return doubled - 1.0
+
+
+def generate_grid(num_grid, size, device):
+    """Build a regular grid of points in [0, 1] x [0, 1], repeated per region.
+
+    Args:
+        num_grid (int): How many copies of the grid to return (one per
+            region).
+        size (tuple(int, int)): The side size of the regular grid.
+        device (torch.device): Device on which the grid is created.
+
+    Returns:
+        (torch.Tensor): Shape (num_grid, size[0]*size[1], 2); the same
+            grid coordinates expanded along the first dimension.
+    """
+    identity_theta = torch.tensor(
+        [[[1., 0., 0.], [0., 1., 0.]]], device=device)
+    out_size = torch.Size((1, 1, *size))
+    unit_grid = F.affine_grid(identity_theta, out_size, align_corners=False)
+    unit_grid = (unit_grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
+    return unit_grid.view(1, -1, 2).expand(num_grid, -1, -1)
+
+
+def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
+    """Turn RoI-relative point coordinates into absolute image coordinates.
+
+    Args:
+        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5); with five
+            columns the first one (batch index) is dropped.
+        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+            RoI, location, range (0, 1), shape (N, P, 2)
+    Returns:
+        Tensor: Image based absolute point coordinates, shape (N, P, 2)
+    """
+
+    with torch.no_grad():
+        assert rel_roi_points.size(0) == rois.size(0)
+        assert rois.dim() == 2
+        assert rel_roi_points.dim() == 3
+        assert rel_roi_points.size(2) == 2
+        # a 5-column RoI carries the batch index first; keep only the box
+        boxes = rois[:, 1:] if rois.size(1) == 5 else rois
+        x_min = boxes[:, None, 0]
+        y_min = boxes[:, None, 1]
+        roi_w = boxes[:, None, 2] - x_min
+        roi_h = boxes[:, None, 3] - y_min
+        # Work on separate tensors rather than writing into the input, which
+        # avoids an error while exporting to onnx.
+        xs = rel_roi_points[:, :, 0] * roi_w + x_min
+        ys = rel_roi_points[:, :, 1] * roi_h + y_min
+        abs_points = torch.stack([xs, ys], dim=2)
+    return abs_points
+
+
+def get_shape_from_feature_map(x):
+    """Spatial size of a feature map as a float (width, height) tensor;
+    also usable while exporting to onnx.
+
+    Args:
+        x (torch.Tensor): Input tensor, shape (N, C, H, W)
+    Returns:
+        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+    """
+    if torch.onnx.is_in_onnx_export():
+        spatial = shape_as_tensor(x)[2:]
+    else:
+        spatial = torch.tensor(x.shape[2:])
+    # reverse (H, W) -> (W, H) and reshape to (1, 1, 2)
+    width_height = spatial.flip(0).view(1, 1, 2)
+    return width_height.to(x.device).float()
+
+
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
+    """Scale absolute image point coordinates into relative coordinates
+    that can be used for sampling.
+
+    Args:
+        abs_img_points (Tensor): Image based absolute point coordinates,
+            shape (N, P, 2)
+        img (tuple/Tensor): (height, width) of image or feature map.
+        spatial_scale (float): Scale points by this factor. Default: 1.
+
+    Returns:
+        Tensor: Image based relative point coordinates for sampling,
+            shape (N, P, 2)
+    """
+
+    is_hw_pair = isinstance(img, tuple) and len(img) == 2
+    is_nchw = isinstance(img, torch.Tensor) and len(img.shape) == 4
+    assert is_hw_pair or is_nchw
+
+    if not isinstance(img, tuple):
+        scale = get_shape_from_feature_map(img)
+    else:
+        height, width = img
+        # points are (x, y), so the divisor is ordered (width, height)
+        scale = torch.tensor(
+            [width, height], dtype=torch.float,
+            device=abs_img_points.device).view(1, 1, 2)
+    return abs_img_points / scale * spatial_scale
+
+
+def rel_roi_point_to_rel_img_point(rois,
+                                   rel_roi_points,
+                                   img,
+                                   spatial_scale=1.):
+    """Convert roi based relative point coordinates to image based relative
+    point coordinates.
+
+    Args:
+        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+            RoI, location, range (0, 1), shape (N, P, 2)
+        img (tuple/Tensor): (height, width) of image or feature map.
+        spatial_scale (float): Scale points by this factor. Default: 1.
+
+    Returns:
+        Tensor: Image based relative point coordinates for sampling,
+            shape (N, P, 2)
+    """
+
+    # Two-step conversion: RoI-relative -> absolute -> image-relative.
+    abs_points = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
+    rel_points = abs_img_point_to_rel_img_point(
+        abs_points, img, spatial_scale)
+    return rel_points
+
+
+def point_sample(input, points, align_corners=False, **kwargs):
+    """Sample features at points given in ``[0, 1] x [0, 1]`` coordinates.
+    A wrapper around :func:`grid_sample` that also accepts 3D point tensors;
+    unlike :func:`torch.nn.functional.grid_sample` the points are not in
+    ``[-1, 1]`` and are rescaled here before sampling.
+
+    Args:
+        input (Tensor): Feature map, shape (N, C, H, W).
+        points (Tensor): Image based absolute point coordinates (normalized),
+            range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
+        align_corners (bool): Whether align_corners. Default: False
+
+    Returns:
+        Tensor: Features of `point` on `input`, shape (N, C, P) or
+            (N, C, Hgrid, Wgrid).
+    """
+
+    squeeze_back = points.dim() == 3
+    if squeeze_back:
+        # (N, P, 2) -> (N, P, 1, 2) so the samplers receive a 4D grid
+        points = points.unsqueeze(2)
+    use_python_sampler = is_in_onnx_export_without_custom_ops()
+    grid = denormalize(points)
+    if use_python_sampler:
+        # Custom ops for onnx runtime are not compiled: use the python
+        # implementation of grid_sample so the onnx graph is made of
+        # supported nodes.
+        sampled = bilinear_grid_sample(
+            input, grid, align_corners=align_corners)
+    else:
+        sampled = F.grid_sample(
+            input, grid, align_corners=align_corners, **kwargs)
+    return sampled.squeeze(3) if squeeze_back else sampled
+
+
+class SimpleRoIAlign(nn.Module):
+    """RoI feature extractor that samples a regular point grid in every RoI."""
+    def __init__(self, output_size, spatial_scale, aligned=True):
+        """Simple RoI align in PointRend, faster than standard RoIAlign.
+
+        Args:
+            output_size (tuple[int]): h, w
+            spatial_scale (float): scale the input boxes by this number
+            aligned (bool): if False, use the legacy implementation in
+                MMDetection, align_corners=True will be used in F.grid_sample.
+                If True, align the results more perfectly.
+        """
+
+        super(SimpleRoIAlign, self).__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        # to be consistent with other RoI ops
+        self.use_torchvision = False
+        self.aligned = aligned
+
+    def forward(self, features, rois):
+        num_imgs = features.size(0)
+        num_rois = rois.size(0)
+        rel_roi_points = generate_grid(
+            num_rois, self.output_size, device=rois.device)
+        # ONNX export samples all RoIs in one call; eager mode loops per image.
+        if torch.onnx.is_in_onnx_export():
+            rel_img_points = rel_roi_point_to_rel_img_point(
+                rois, rel_roi_points, features, self.spatial_scale)
+            rel_img_points = rel_img_points.reshape(num_imgs, -1,
+                                                    *rel_img_points.shape[1:])
+            point_feats = point_sample(
+                features, rel_img_points, align_corners=not self.aligned)
+            point_feats = point_feats.transpose(1, 2)
+        else:
+            point_feats = []
+            for batch_ind in range(num_imgs):
+                # unravel batch dim
+                feat = features[batch_ind].unsqueeze(0)
+                inds = (rois[:, 0].long() == batch_ind)
+                if inds.any():
+                    rel_img_points = rel_roi_point_to_rel_img_point(
+                        rois[inds], rel_roi_points[inds], feat,
+                        self.spatial_scale).unsqueeze(0)
+                    point_feat = point_sample(
+                        feat, rel_img_points, align_corners=not self.aligned)
+                    point_feat = point_feat.squeeze(0).transpose(0, 1)
+                    point_feats.append(point_feat)
+
+            point_feats = torch.cat(point_feats, dim=0)
+
+        channels = features.size(1)
+        roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
+
+        return roi_feats
+
+    def __repr__(self):
+        format_str = self.__class__.__name__
+        format_str += '(output_size={}, spatial_scale={})'.format(
+            self.output_size, self.spatial_scale)
+        return format_str
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/points_in_boxes.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/points_in_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/points_in_boxes.py
@@ -0,0 +1,133 @@
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
+ 'points_in_boxes_all_forward'
+])
+
+
+def points_in_boxes_part(points, boxes):
+    """Find the box in which each point is (CUDA).
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
+            LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M) int tensor, background = -1
+    """
+    assert points.shape[0] == boxes.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+    # Result buffer pre-filled with -1; presumably written in place by the op.
+    box_idxs_of_pts = points.new_zeros((batch_size, num_points),
+                                       dtype=torch.int).fill_(-1)
+
+    # If manually put the tensor 'points' or 'boxes' on a device
+    # which is not the current device, some temporary variables
+    # will be created on the current device in the cuda op,
+    # and the output will be incorrect.
+    # Therefore, we force the current device to be the same
+    # as the device of the tensors if it was not.
+    # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
+    # for the incorrect output before the fix.
+    points_device = points.get_device()
+    assert points_device == boxes.get_device(), \
+        'Points and boxes should be put on the same device'
+    if torch.cuda.current_device() != points_device:
+        torch.cuda.set_device(points_device)
+
+    ext_module.points_in_boxes_part_forward(boxes.contiguous(),
+                                            points.contiguous(),
+                                            box_idxs_of_pts)
+
+    return box_idxs_of_pts
+
+
+def points_in_boxes_cpu(points, boxes):
+    """Find all boxes in which each point is (CPU). The CPU version of
+    :meth:`points_in_boxes_all`.
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in
+            LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+            (x, y, z) is the bottom center.
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+    """
+    assert points.shape[0] == boxes.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+    num_boxes = boxes.shape[1]
+    # Allocated as (B, T, M) for the op; transposed to (B, M, T) on return.
+    point_indices = points.new_zeros((batch_size, num_boxes, num_points),
+                                     dtype=torch.int)
+    for b in range(batch_size):
+        ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(),
+                                               points[b].float().contiguous(),
+                                               point_indices[b])
+    point_indices = point_indices.transpose(1, 2)
+
+    return point_indices
+
+
+def points_in_boxes_all(points, boxes):
+    """Find all boxes in which each point is (CUDA).
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+            (x, y, z) is the bottom center.
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+    """
+    assert boxes.shape[0] == points.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+    num_boxes = boxes.shape[1]
+
+    box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
+                                       dtype=torch.int).fill_(0)
+
+    # Same device-switching rationale as in points_in_boxes_part.
+    points_device = points.get_device()
+    assert points_device == boxes.get_device(), \
+        'Points and boxes should be put on the same device'
+    if torch.cuda.current_device() != points_device:
+        torch.cuda.set_device(points_device)
+
+    ext_module.points_in_boxes_all_forward(boxes.contiguous(),
+                                           points.contiguous(),
+                                           box_idxs_of_pts)
+
+    return box_idxs_of_pts
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/points_sampler.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/points_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1a24f939dd0e2934765326363ea51c2f2b4cca
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/points_sampler.py
@@ -0,0 +1,177 @@
+from typing import List
+
+import torch
+from torch import nn as nn
+
+from annotator.mmpkg.mmcv.runner import force_fp32
+from .furthest_point_sample import (furthest_point_sample,
+ furthest_point_sample_with_dist)
+
+
+def calc_square_dist(point_feat_a, point_feat_b, norm=True):
+    """Calculating square distance between a and b.
+
+    Args:
+        point_feat_a (Tensor): (B, N, C) Feature vector of each point.
+        point_feat_b (Tensor): (B, M, C) Feature vector of each point.
+        norm (Bool, optional): If True, return ``sqrt(dist) / C`` instead
+            of the raw squared distance. Default: True.
+
+    Returns:
+        Tensor: (B, N, M) Distance between each pair points.
+    """
+    num_channel = point_feat_a.shape[-1]
+    # [bs, n, 1]
+    a_square = point_feat_a.pow(2).sum(dim=-1).unsqueeze(dim=2)
+    # [bs, 1, m]
+    b_square = point_feat_b.pow(2).sum(dim=-1).unsqueeze(dim=1)
+    corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
+
+    dist = a_square + b_square - 2 * corr_matrix
+    if norm:
+        # Round-off can push dist slightly below 0; sqrt would give NaN.
+        dist = torch.sqrt(dist.clamp(min=0)) / num_channel
+    return dist
+
+
+def get_sampler_cls(sampler_type):
+    """Get the type and mode of points sampler.
+
+    Args:
+        sampler_type (str): Sampler name: "D-FPS", "F-FPS" or "FS".
+
+    Returns:
+        class: Points sampler type.
+
+    Raises:
+        KeyError: If ``sampler_type`` is not a supported value.
+    """
+    sampler_mappings = {
+        'D-FPS': DFPSSampler,
+        'F-FPS': FFPSSampler,
+        'FS': FSSampler,
+    }
+    if sampler_type not in sampler_mappings:
+        raise KeyError('Supported `sampler_type` are '
+                       f'{list(sampler_mappings)}, but got {sampler_type}')
+    return sampler_mappings[sampler_type]
+
+
+class PointsSampler(nn.Module):
+    """Points sampling.
+
+    Args:
+        num_point (list[int]): Number of sample points.
+        fps_mod_list (list[str], optional): Type of FPS method, valid mod
+            ['F-FPS', 'D-FPS', 'FS'], Default: ['D-FPS'].
+            F-FPS: using feature distances for FPS.
+            D-FPS: using Euclidean distances of points for FPS.
+            FS: using F-FPS and D-FPS simultaneously.
+        fps_sample_range_list (list[int], optional):
+            Range of points to apply FPS. Default: [-1].
+    """
+
+    def __init__(self,
+                 num_point: List[int],
+                 fps_mod_list: List[str] = ['D-FPS'],
+                 fps_sample_range_list: List[int] = [-1]):
+        super().__init__()
+        # FPS would be applied to different fps_mod in the list,
+        # so the length of the num_point should be equal to
+        # fps_mod_list and fps_sample_range_list.
+        assert len(num_point) == len(fps_mod_list) == len(
+            fps_sample_range_list)
+        self.num_point = num_point
+        self.fps_sample_range_list = fps_sample_range_list
+        self.samplers = nn.ModuleList()
+        for fps_mod in fps_mod_list:
+            self.samplers.append(get_sampler_cls(fps_mod)())
+        self.fp16_enabled = False
+
+    @force_fp32()
+    def forward(self, points_xyz, features):
+        """Sample point indices, one sampler per configured range.
+
+        Args:
+            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            features (Tensor or None): (B, C, N) Descriptors of the features.
+
+        Returns:
+            Tensor: Indices of sampled points, with the outputs of all
+                samplers concatenated along dim 1.
+        """
+        indices = []
+        last_fps_end_index = 0
+
+        for fps_sample_range, sampler, npoint in zip(
+                self.fps_sample_range_list, self.samplers, self.num_point):
+            assert fps_sample_range < points_xyz.shape[1]
+
+            # -1 means "up to the last point"; any other value is used as
+            # the (exclusive) end index of this sampler's window.
+            end = None if fps_sample_range == -1 else fps_sample_range
+            sample_points_xyz = points_xyz[:, last_fps_end_index:end]
+            if features is not None:
+                sample_features = features[:, :, last_fps_end_index:end]
+            else:
+                sample_features = None
+
+            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
+                              npoint)
+
+            # Offset the sampler output by the start of its window.
+            indices.append(fps_idx + last_fps_end_index)
+            # The next window starts where this one ended. The range is an
+            # absolute index, so assign it; accumulating with `+=` puts the
+            # window start in the wrong place from the third sampler on.
+            last_fps_end_index = fps_sample_range
+        indices = torch.cat(indices, dim=1)
+
+        return indices
+
+
+class DFPSSampler(nn.Module):
+    """FPS sampler driven by Euclidean point distances (D-FPS)."""
+
+    def __init__(self):
+        super(DFPSSampler, self).__init__()
+
+    def forward(self, points, features, npoint):
+        """Run D-FPS on ``points``; ``features`` is accepted but unused."""
+        xyz = points.contiguous()
+        return furthest_point_sample(xyz, npoint)
+
+
+class FFPSSampler(nn.Module):
+    """FPS sampler driven by feature distances (F-FPS)."""
+
+    def __init__(self):
+        super(FFPSSampler, self).__init__()
+
+    def forward(self, points, features, npoint):
+        """Run F-FPS on the points joined with their transposed features."""
+        assert features is not None, \
+            'feature input to FFPS_Sampler should not be None'
+        # Join coordinates and features along the last dim.
+        joint = torch.cat([points, features.transpose(1, 2)], dim=2)
+        # Pairwise squared distances, left un-normalised.
+        pairwise = calc_square_dist(joint, joint, norm=False)
+        return furthest_point_sample_with_dist(pairwise, npoint)
+
+
+class FSSampler(nn.Module):
+    """Sampler that combines F-FPS and D-FPS (FS)."""
+
+    def __init__(self):
+        super(FSSampler, self).__init__()
+
+    def forward(self, points, features, npoint):
+        """Run F-FPS then D-FPS and join both index sets along dim 1."""
+        assert features is not None, \
+            'feature input to FS_Sampler should not be None'
+        # Fresh sub-samplers are built on every call, F-FPS first.
+        by_feature = FFPSSampler()(points, features, npoint)
+        by_distance = DFPSSampler()(points, features, npoint)
+        # F-FPS indices come first in the result.
+        sampled = [by_feature, by_distance]
+        return torch.cat(sampled, dim=1)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/psa_mask.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/psa_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/psa_mask.py
@@ -0,0 +1,92 @@
+# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['psamask_forward', 'psamask_backward'])
+
+
+class PSAMaskFunction(Function):
+    """Autograd function wrapping the ``psamask`` extension kernels."""
+
+    @staticmethod
+    def symbolic(g, input, psa_type, mask_size):
+        """Emit the ``mmcv::MMCVPSAMask`` op for ONNX export."""
+        onnx_attrs = dict(psa_type_i=psa_type, mask_size_i=mask_size)
+        return g.op('mmcv::MMCVPSAMask', input, **onnx_attrs)
+
+    @staticmethod
+    def forward(ctx, input, psa_type, mask_size):
+        """Run ``psamask_forward`` on an (N, C, H, W) ``input``.
+
+        ``C`` must equal ``h_mask * w_mask``; the output allocated here has
+        shape (N, H * W, H, W).
+        """
+        ctx.psa_type = psa_type
+        ctx.mask_size = _pair(mask_size)
+        ctx.save_for_backward(input)
+
+        mask_h, mask_w = ctx.mask_size
+        n, c, feat_h, feat_w = input.size()
+        assert c == mask_h * mask_w
+        output = input.new_zeros((n, feat_h * feat_w, feat_h, feat_w))
+
+        kernel_kwargs = dict(
+            psa_type=psa_type,
+            num_=n,
+            h_feature=feat_h,
+            w_feature=feat_w,
+            h_mask=mask_h,
+            w_mask=mask_w,
+            half_h_mask=(mask_h - 1) // 2,
+            half_w_mask=(mask_w - 1) // 2)
+        ext_module.psamask_forward(input, output, **kernel_kwargs)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Run ``psamask_backward`` and return the gradient for ``input``."""
+        (input, ) = ctx.saved_tensors
+        mask_h, mask_w = ctx.mask_size
+        n, c, feat_h, feat_w = input.size()
+        grad_input = grad_output.new_zeros((n, c, feat_h, feat_w))
+        kernel_kwargs = dict(
+            psa_type=ctx.psa_type,
+            num_=n,
+            h_feature=feat_h,
+            w_feature=feat_w,
+            h_mask=mask_h,
+            w_mask=mask_w,
+            half_h_mask=(mask_h - 1) // 2,
+            half_w_mask=(mask_w - 1) // 2)
+        ext_module.psamask_backward(grad_output, grad_input, **kernel_kwargs)
+        return grad_input, None, None, None
+
+
+psa_mask = PSAMaskFunction.apply
+
+
+class PSAMask(nn.Module):
+    """Module wrapper around ``psa_mask`` ('collect' or 'distribute')."""
+
+    def __init__(self, psa_type, mask_size=None):
+        super(PSAMask, self).__init__()
+        assert psa_type in ['collect', 'distribute']
+        # Integer code handed to psa_mask: 0 for 'collect', 1 otherwise.
+        self.psa_type_enum = 0 if psa_type == 'collect' else 1
+        self.mask_size = mask_size
+        self.psa_type = psa_type
+
+    def forward(self, input):
+        """Apply ``psa_mask`` to ``input`` with the stored settings."""
+        type_code, size = self.psa_type_enum, self.mask_size
+        return psa_mask(input, type_code, size)
+
+    def __repr__(self):
+        """Return ``ClassName(psa_type=..., mask_size=...)``."""
+        cls_name = self.__class__.__name__
+        fields = f'psa_type={self.psa_type}, mask_size={self.mask_size}'
+        return f'{cls_name}({fields})'
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_align.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_align.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import deprecated_api_warning, ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_align_forward', 'roi_align_backward'])
+
+
+class RoIAlignFunction(Function):
+    """Autograd function backing ``roi_align``."""
+
+    @staticmethod
+    def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
+                 pool_mode, aligned):
+        """ONNX export: custom op if loaded, else the ``RoiAlign`` op."""
+        from ..onnx import is_custom_op_loaded
+        if is_custom_op_loaded():
+            return g.op(
+                'mmcv::MMCVRoiAlign',
+                input,
+                rois,
+                output_height_i=output_size[0],
+                output_width_i=output_size[1],
+                spatial_scale_f=spatial_scale,
+                sampling_ratio_i=sampling_ratio,
+                mode_s=pool_mode,
+                aligned_i=aligned)
+
+        from torch.onnx import TensorProtoDataType
+        from torch.onnx.symbolic_helper import _slice_helper
+        from torch.onnx.symbolic_opset9 import squeeze, sub
+        # Column 0 of rois -> int64 batch indices (rois[:, 0].long()).
+        first_col = _slice_helper(g, rois, axes=[1], starts=[0], ends=[1])
+        first_col = squeeze(g, first_col, 1)
+        batch_indices = g.op(
+            'Cast', first_col, to_i=TensorProtoDataType.INT64)
+        # Columns 1..4 of rois are kept as the boxes (rois[:, 1:]).
+        boxes = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+        if aligned:
+            # boxes -= 0.5 / spatial_scale
+            half_pixel = g.op(
+                'Constant',
+                value_t=torch.tensor([0.5 / spatial_scale],
+                                     dtype=torch.float32))
+            boxes = sub(g, boxes, half_pixel)
+        return g.op(
+            'RoiAlign',
+            input,
+            boxes,
+            batch_indices,
+            output_height_i=output_size[0],
+            output_width_i=output_size[1],
+            spatial_scale_f=spatial_scale,
+            sampling_ratio_i=max(0, sampling_ratio),
+            mode_s=pool_mode)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                rois,
+                output_size,
+                spatial_scale=1.0,
+                sampling_ratio=0,
+                pool_mode='avg',
+                aligned=True):
+        """Pool ``input`` over ``rois`` with the ``roi_align_forward`` op."""
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sampling_ratio = sampling_ratio
+        assert pool_mode in ('max', 'avg')
+        # Integer mode passed to the extension: 0 = 'max', 1 = 'avg'.
+        ctx.pool_mode = 0 if pool_mode == 'max' else 1
+        ctx.aligned = aligned
+        ctx.input_shape = input.size()
+
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+        out_h, out_w = ctx.output_size[0], ctx.output_size[1]
+        output_shape = (rois.size(0), input.size(1), out_h, out_w)
+        output = input.new_zeros(output_shape)
+        # Full-size argmax buffers only in 'max' mode; empty ones otherwise.
+        argmax_shape = output_shape if ctx.pool_mode == 0 else 0
+        argmax_y = input.new_zeros(argmax_shape)
+        argmax_x = input.new_zeros(argmax_shape)
+
+        ext_module.roi_align_forward(
+            input,
+            rois,
+            output,
+            argmax_y,
+            argmax_x,
+            aligned_height=out_h,
+            aligned_width=out_w,
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            pool_mode=ctx.pool_mode,
+            aligned=ctx.aligned)
+
+        ctx.save_for_backward(rois, argmax_y, argmax_x)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """Return the gradient w.r.t. ``input``; other inputs get None."""
+        rois, argmax_y, argmax_x = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(ctx.input_shape)
+        # grad_output may arrive non-contiguous (complex head architectures).
+        grad_output = grad_output.contiguous()
+        ext_module.roi_align_backward(
+            grad_output,
+            rois,
+            argmax_y,
+            argmax_x,
+            grad_input,
+            aligned_height=ctx.output_size[0],
+            aligned_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            pool_mode=ctx.pool_mode,
+            aligned=ctx.aligned)
+        return grad_input, None, None, None, None, None, None
+
+
+roi_align = RoIAlignFunction.apply
+
+
+class RoIAlign(nn.Module):
+    """RoI align pooling layer.
+
+    Args:
+        output_size (tuple): h, w
+        spatial_scale (float): scale the input boxes by this number
+        sampling_ratio (int): number of inputs samples to take for each
+            output sample. 0 to take samples densely for current models.
+        pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
+        aligned (bool): if False, use the legacy implementation in
+            MMDetection. If True, align the results more perfectly.
+        use_torchvision (bool): use torchvision roi_align; pool_mode is unused.
+
+    Note:
+        The implementation of RoIAlign when aligned=True is modified from
+        https://github.com/facebookresearch/detectron2/
+
+        The meaning of aligned=True:
+
+        Given a continuous coordinate c, its two neighboring pixel
+        indices (in our pixel model) are computed by floor(c - 0.5) and
+        ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+        indices [0] and [1] (which are sampled from the underlying signal
+        at continuous coordinates 0.5 and 1.5). But the original roi_align
+        (aligned=False) does not subtract the 0.5 when computing
+        neighboring pixel indices and therefore it uses pixels with a
+        slightly incorrect alignment (relative to our pixel model) when
+        performing bilinear interpolation.
+
+        With `aligned=True`,
+        we first appropriately scale the ROI and then shift it by -0.5
+        prior to calling roi_align. This produces the correct neighbors;
+
+        The difference does not make a difference to the model's
+        performance if ROIAlign is used together with conv layers.
+    """
+
+    @deprecated_api_warning(
+        {
+            'out_size': 'output_size',
+            'sample_num': 'sampling_ratio'
+        },
+        cls_name='RoIAlign')
+    def __init__(self,
+                 output_size,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 pool_mode='avg',
+                 aligned=True,
+                 use_torchvision=False):
+        super(RoIAlign, self).__init__()
+
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        self.sampling_ratio = int(sampling_ratio)
+        self.pool_mode = pool_mode
+        self.aligned = aligned
+        self.use_torchvision = use_torchvision
+
+    def forward(self, input, rois):
+        """
+        Args:
+            input: NCHW images
+            rois: Bx5 boxes. First column is the index into N.
+                The other 4 columns are xyxy. Not modified by this call.
+        """
+        if not self.use_torchvision:
+            return roi_align(input, rois, self.output_size, self.spatial_scale,
+                             self.sampling_ratio, self.pool_mode, self.aligned)
+        from torchvision.ops import roi_align as tv_roi_align
+        if 'aligned' in tv_roi_align.__code__.co_varnames:
+            return tv_roi_align(input, rois, self.output_size,
+                                self.spatial_scale, self.sampling_ratio,
+                                self.aligned)
+        if self.aligned:
+            # Shift a copy: an in-place `-=` would modify the caller's
+            # rois tensor and shift it again on every subsequent call.
+            rois = rois - rois.new_tensor([0.] +
+                                          [0.5 / self.spatial_scale] * 4)
+        return tv_roi_align(input, rois, self.output_size,
+                            self.spatial_scale, self.sampling_ratio)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(output_size={self.output_size}, '
+        s += f'spatial_scale={self.spatial_scale}, '
+        s += f'sampling_ratio={self.sampling_ratio}, '
+        s += f'pool_mode={self.pool_mode}, '
+        s += f'aligned={self.aligned}, '
+        s += f'use_torchvision={self.use_torchvision})'
+        return s
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_align_rotated.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_align_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_align_rotated.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
+
+
+class RoIAlignRotatedFunction(Function):
+    """Autograd function wrapping the rotated RoI align extension ops."""
+
+    @staticmethod
+    def _unpack_out_size(out_size):
+        """Return ``(out_h, out_w)`` for an int or a 2-tuple of ints."""
+        if isinstance(out_size, int):
+            return out_size, out_size
+        if isinstance(out_size, tuple):
+            assert len(out_size) == 2
+            assert isinstance(out_size[0], int)
+            assert isinstance(out_size[1], int)
+            return out_size
+        raise TypeError(
+            '"out_size" must be an integer or tuple of integers')
+
+    @staticmethod
+    def symbolic(g, features, rois, out_size, spatial_scale, sample_num,
+                 aligned, clockwise):
+        """Emit the ``mmcv::MMCVRoIAlignRotated`` op for ONNX export."""
+        out_h, out_w = RoIAlignRotatedFunction._unpack_out_size(out_size)
+        # The width attribute must be out_w: exporting out_h for both
+        # dimensions is only right when the output is square.
+        return g.op(
+            'mmcv::MMCVRoIAlignRotated',
+            features,
+            rois,
+            output_height_i=out_h,
+            output_width_i=out_w,
+            spatial_scale_f=spatial_scale,
+            sampling_ratio_i=sample_num,
+            aligned_i=aligned,
+            clockwise_i=clockwise)
+
+    @staticmethod
+    def forward(ctx,
+                features,
+                rois,
+                out_size,
+                spatial_scale,
+                sample_num=0,
+                aligned=True,
+                clockwise=False):
+        """Pool ``features`` (N, C, H, W) over ``rois`` into out_h x out_w."""
+        out_h, out_w = RoIAlignRotatedFunction._unpack_out_size(out_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sample_num = sample_num
+        ctx.aligned = aligned
+        ctx.clockwise = clockwise
+        ctx.save_for_backward(rois)
+        ctx.feature_size = features.size()
+
+        batch_size, num_channels, data_height, data_width = features.size()
+        num_rois = rois.size(0)
+
+        output = features.new_zeros(num_rois, num_channels, out_h, out_w)
+        ext_module.roi_align_rotated_forward(
+            features,
+            rois,
+            output,
+            pooled_height=out_h,
+            pooled_width=out_w,
+            spatial_scale=spatial_scale,
+            sample_num=sample_num,
+            aligned=aligned,
+            clockwise=clockwise)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Return the gradient w.r.t. ``features`` when it is required."""
+        feature_size = ctx.feature_size
+        spatial_scale = ctx.spatial_scale
+        aligned = ctx.aligned
+        clockwise = ctx.clockwise
+        sample_num = ctx.sample_num
+        rois = ctx.saved_tensors[0]
+        assert feature_size is not None
+        batch_size, num_channels, data_height, data_width = feature_size
+
+        out_w = grad_output.size(3)
+        out_h = grad_output.size(2)
+
+        # grad_rois is never computed here and is always returned as None.
+        grad_input = grad_rois = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+                                        data_width)
+            ext_module.roi_align_rotated_backward(
+                grad_output.contiguous(),
+                rois,
+                grad_input,
+                pooled_height=out_h,
+                pooled_width=out_w,
+                spatial_scale=spatial_scale,
+                sample_num=sample_num,
+                aligned=aligned,
+                clockwise=clockwise)
+        return grad_input, grad_rois, None, None, None, None, None
+
+
+roi_align_rotated = RoIAlignRotatedFunction.apply
+
+
+class RoIAlignRotated(nn.Module):
+    """RoI align pooling layer for rotated proposals.
+
+    Takes a feature map of shape (N, C, H, W) together with rois of shape
+    (n, 6), where each roi is (batch_index, center_x, center_y, w, h,
+    angle) and the angle is expressed in radians.
+
+    Args:
+        out_size (int or tuple): output size, a single int or (h, w).
+        spatial_scale (float): scale the input boxes by this number
+        sample_num (int): number of inputs samples to take for each
+            output sample. 0 to take samples densely for current models.
+        aligned (bool): if False, use the legacy implementation in
+            MMDetection. If True, align the results more perfectly.
+            Default: True.
+        clockwise (bool): If True, the angle in each proposal follows a
+            clockwise fashion in image space, otherwise, the angle is
+            counterclockwise. Default: False.
+
+    Note:
+        The aligned=True implementation is adapted from
+        https://github.com/facebookresearch/detectron2/
+
+        What aligned=True means:
+
+        In our pixel model, the two pixels neighbouring a continuous
+        coordinate c have indices floor(c - 0.5) and ceil(c - 0.5). For
+        example, c=1.3 lies between the pixels with discrete indices [0]
+        and [1], which sample the underlying signal at continuous
+        coordinates 0.5 and 1.5. The original roi_align (aligned=False)
+        skips the 0.5 subtraction when computing the neighbouring pixel
+        indices, so its bilinear interpolation uses pixels that are
+        slightly misaligned relative to this pixel model.
+
+        With `aligned=True`, the ROI is first scaled appropriately and
+        then shifted by -0.5 before roi_align is called, which yields
+        the correct neighbors.
+
+        This difference does not affect the model's performance when
+        ROIAlign is used together with conv layers.
+    """
+
+    def __init__(self,
+                 out_size,
+                 spatial_scale,
+                 sample_num=0,
+                 aligned=True,
+                 clockwise=False):
+        super().__init__()
+
+        self.out_size = out_size
+        # Normalise the numeric settings once, at construction time.
+        self.spatial_scale = float(spatial_scale)
+        self.sample_num = int(sample_num)
+        self.aligned = aligned
+        self.clockwise = clockwise
+
+    def forward(self, features, rois):
+        """Pool ``features`` over the rotated ``rois``."""
+        settings = (self.out_size, self.spatial_scale, self.sample_num,
+                    self.aligned, self.clockwise)
+        return RoIAlignRotatedFunction.apply(features, rois, *settings)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_pool.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roi_pool.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['roi_pool_forward', 'roi_pool_backward'])
+
+
+class RoIPoolFunction(Function):
+    """Autograd function wrapping the ``roi_pool`` extension ops."""
+
+    @staticmethod
+    def symbolic(g, input, rois, output_size, spatial_scale):
+        """Emit the ``MaxRoiPool`` op for ONNX export."""
+        onnx_attrs = dict(
+            pooled_shape_i=output_size, spatial_scale_f=spatial_scale)
+        return g.op('MaxRoiPool', input, rois, **onnx_attrs)
+
+    @staticmethod
+    def forward(ctx, input, rois, output_size, spatial_scale=1.0):
+        """Pool ``input`` over ``rois`` with ``roi_pool_forward``.
+
+        ``rois`` must have 5 columns, (idx, x1, y1, x2, y2). The output
+        allocated here has shape (num_rois, channels, out_h, out_w).
+        """
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.input_shape = input.size()
+
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+        out_h, out_w = ctx.output_size[0], ctx.output_size[1]
+        output_shape = (rois.size(0), input.size(1), out_h, out_w)
+        output = input.new_zeros(output_shape)
+        # Integer buffer with the output's shape, handed to the forward op.
+        argmax = input.new_zeros(output_shape, dtype=torch.int)
+
+        ext_module.roi_pool_forward(
+            input,
+            rois,
+            output,
+            argmax,
+            pooled_height=out_h,
+            pooled_width=out_w,
+            spatial_scale=ctx.spatial_scale)
+
+        ctx.save_for_backward(rois, argmax)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """Return the gradient w.r.t. ``input``; other inputs get None."""
+        rois, argmax = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(ctx.input_shape)
+        ext_module.roi_pool_backward(
+            grad_output, rois, argmax, grad_input,
+            pooled_height=ctx.output_size[0], pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale)
+        return grad_input, None, None, None
+
+
+roi_pool = RoIPoolFunction.apply
+
+
+class RoIPool(nn.Module):
+    """Module wrapper around ``roi_pool`` with fixed settings."""
+
+    def __init__(self, output_size, spatial_scale=1.0):
+        super().__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+
+    def forward(self, input, rois):
+        """Apply ``roi_pool`` to ``input`` and ``rois``."""
+        return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+    def __repr__(self):
+        name = self.__class__.__name__
+        return (f'{name}(output_size={self.output_size}, '
+                f'spatial_scale={self.spatial_scale})')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roiaware_pool3d.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roiaware_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..8191920ca50b388ef58f577dc986da101662ac53
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roiaware_pool3d.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+import annotator.mmpkg.mmcv as mmcv
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])
+
+
+class RoIAwarePool3d(nn.Module):
+    """Encode the geometry-specific features of each 3D proposal.
+
+    Please refer to the PartA2 paper for more details.
+
+    Args:
+        out_size (int or tuple): The size of output features. n or
+            [n1, n2, n3].
+        max_pts_per_voxel (int, optional): The maximum number of points per
+            voxel. Default: 128.
+        mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
+            Default: 'max'. Stored on the module as 0 ('max') or 1 ('avg').
+    """
+
+    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
+        super(RoIAwarePool3d, self).__init__()
+
+        assert mode in ('max', 'avg')
+        self.out_size = out_size
+        self.max_pts_per_voxel = max_pts_per_voxel
+        # Integer code passed on to the pooling function.
+        self.mode = {'max': 0, 'avg': 1}[mode]
+
+    def forward(self, rois, pts, pts_feature):
+        """Pool point features into the voxels of each roi.
+
+        Args:
+            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
+                (x, y, z) is the bottom center of rois.
+            pts (torch.Tensor): [npoints, 3], coordinates of input points.
+            pts_feature (torch.Tensor): [npoints, C], features of input points.
+
+        Returns:
+            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
+        """
+
+        pool_args = (self.out_size, self.max_pts_per_voxel, self.mode)
+        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
+                                            *pool_args)
+
+
+class RoIAwarePool3dFunction(Function):
+    """Autograd function wrapping the ``roiaware_pool3d`` extension ops."""
+
+    @staticmethod
+    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
+                mode):
+        """
+        Args:
+            rois (torch.Tensor): [N, 7], in LiDAR coordinate,
+                (x, y, z) is the bottom center of rois.
+            pts (torch.Tensor): [npoints, 3], coordinates of input points.
+            pts_feature (torch.Tensor): [npoints, C], features of input points.
+            out_size (int or tuple): The size of output features. n or
+                [n1, n2, n3].
+            max_pts_per_voxel (int): The maximum number of points per voxel.
+            mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
+                pool).
+
+        Returns:
+            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output
+                pooled features.
+        """
+        if not isinstance(out_size, int):
+            assert len(out_size) == 3
+            assert mmcv.is_tuple_of(out_size, int)
+            out_x, out_y, out_z = out_size
+        else:
+            out_x = out_y = out_z = out_size
+
+        num_rois = rois.shape[0]
+        num_channels = pts_feature.shape[-1]
+        num_pts = pts.shape[0]
+
+        # All buffers share the leading (num_rois, out_x, out_y, out_z) grid.
+        grid = (num_rois, out_x, out_y, out_z)
+        pooled_features = pts_feature.new_zeros(grid + (num_channels, ))
+        argmax = pts_feature.new_zeros(
+            grid + (num_channels, ), dtype=torch.int)
+        pts_idx_of_voxels = pts_feature.new_zeros(
+            grid + (max_pts_per_voxel, ), dtype=torch.int)
+
+        ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax,
+                                           pts_idx_of_voxels, pooled_features,
+                                           mode)
+
+        # Stashed as a plain attribute on ctx for use in backward.
+        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
+                                            num_pts, num_channels)
+        return pooled_features
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        """Return the gradient w.r.t. ``pts_feature``; others get None."""
+        (pts_idx_of_voxels, argmax, mode, num_pts,
+         num_channels) = ctx.roiaware_pool3d_for_backward
+
+        grad_in = grad_out.new_zeros((num_pts, num_channels))
+        ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax,
+                                            grad_out.contiguous(), grad_in,
+                                            mode)
+        return None, None, grad_in, None, None, None
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roipoint_pool3d.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roipoint_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/roipoint_pool3d.py
@@ -0,0 +1,77 @@
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])
+
+
+class RoIPointPool3d(nn.Module):
+    """Encode the geometry-specific features of each 3D proposal.
+
+    Please refer to the PartA2 paper for more details.
+
+    Args:
+        num_sampled_points (int, optional): Number of samples in each roi.
+            Default: 512.
+    """
+
+    def __init__(self, num_sampled_points=512):
+        super(RoIPointPool3d, self).__init__()
+        self.num_sampled_points = num_sampled_points
+
+    def forward(self, points, point_features, boxes3d):
+        """
+        Args:
+            points (torch.Tensor): Input points whose shape is (B, N, 3).
+            point_features (torch.Tensor): Features of input points whose
+                shape is (B, N, C).
+            boxes3d (torch.Tensor): Input bounding boxes, shape (B, M, 7).
+
+        Returns:
+            pooled_features (torch.Tensor): The output pooled features whose
+                shape is (B, M, num_sampled_points, 3 + C).
+            pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+        """
+        n_samples = self.num_sampled_points
+        return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
+                                            n_samples)
+
+
+class RoIPointPool3dFunction(Function):
+    """Forward-only autograd function for ``roipoint_pool3d_forward``."""
+
+    @staticmethod
+    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
+        """
+        Args:
+            points (torch.Tensor): Input points whose shape is (B, N, 3).
+            point_features (torch.Tensor): Features of input points whose
+                shape is (B, N, C).
+            boxes3d (torch.Tensor): Input boxes whose shape is (B, M, 7).
+            num_sampled_points (int, optional): The num of sampled points.
+                Default: 512.
+
+        Returns:
+            pooled_features (torch.Tensor): The output pooled features whose
+                shape is (B, M, num_sampled_points, 3 + C).
+            pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+        """
+        assert len(points.shape) == 3 and points.shape[2] == 3
+        batch_size = points.shape[0]
+        boxes_num = boxes3d.shape[1]
+        feature_len = point_features.shape[2]
+        pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
+        pooled_features = point_features.new_zeros(
+            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
+        pooled_empty_flag = point_features.new_zeros(
+            (batch_size, boxes_num)).int()
+        ext_module.roipoint_pool3d_forward(points.contiguous(),
+                                           pooled_boxes3d.contiguous(),
+                                           point_features.contiguous(),
+                                           pooled_features, pooled_empty_flag)
+        return pooled_features, pooled_empty_flag
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        raise NotImplementedError
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/saconv.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/saconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d7be88c428ea2b9af2c32c60a86dddd13988ce8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/saconv.py
@@ -0,0 +1,145 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.mmpkg.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
+from annotator.mmpkg.mmcv.ops.deform_conv import deform_conv2d
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+
+
+@CONV_LAYERS.register_module(name='SAC')
+class SAConv2d(ConvAWS2d):
+ """SAC (Switchable Atrous Convolution)
+
+ This is an implementation of SAC in DetectoRS
+ (https://arxiv.org/pdf/2006.02334.pdf).
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 0
+ padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 1
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If ``True``, adds a learnable bias to the
+ output. Default: ``True``
+ use_deform: If ``True``, replace convolution with deformable
+ convolution. Default: ``False``.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ use_deform=False):
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias)
+ self.use_deform = use_deform
+ self.switch = nn.Conv2d(
+ self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
+ self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
+ self.pre_context = nn.Conv2d(
+ self.in_channels, self.in_channels, kernel_size=1, bias=True)
+ self.post_context = nn.Conv2d(
+ self.out_channels, self.out_channels, kernel_size=1, bias=True)
+ if self.use_deform:
+ self.offset_s = nn.Conv2d(
+ self.in_channels,
+ 18,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ bias=True)
+ self.offset_l = nn.Conv2d(
+ self.in_channels,
+ 18,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ constant_init(self.switch, 0, bias=1)
+ self.weight_diff.data.zero_()
+ constant_init(self.pre_context, 0)
+ constant_init(self.post_context, 0)
+ if self.use_deform:
+ constant_init(self.offset_s, 0)
+ constant_init(self.offset_l, 0)
+
+ def forward(self, x):
+ # pre-context
+ avg_x = F.adaptive_avg_pool2d(x, output_size=1)
+ avg_x = self.pre_context(avg_x)
+ avg_x = avg_x.expand_as(x)
+ x = x + avg_x
+ # switch
+ avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
+ avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
+ switch = self.switch(avg_x)
+ # sac
+ weight = self._get_weight(self.weight)
+ zero_bias = torch.zeros(
+ self.out_channels, device=weight.device, dtype=weight.dtype)
+
+ if self.use_deform:
+ offset = self.offset_s(avg_x)
+ out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
+ self.dilation, self.groups, 1)
+ else:
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+ out_s = super().conv2d_forward(x, weight)
+ elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_s = super()._conv_forward(x, weight, zero_bias)
+ else:
+ out_s = super()._conv_forward(x, weight)
+ ori_p = self.padding
+ ori_d = self.dilation
+ self.padding = tuple(3 * p for p in self.padding)
+ self.dilation = tuple(3 * d for d in self.dilation)
+ weight = weight + self.weight_diff
+ if self.use_deform:
+ offset = self.offset_l(avg_x)
+ out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
+ self.dilation, self.groups, 1)
+ else:
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+ out_l = super().conv2d_forward(x, weight)
+ elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+ # bias is a required argument of _conv_forward in torch 1.8.0
+ out_l = super()._conv_forward(x, weight, zero_bias)
+ else:
+ out_l = super()._conv_forward(x, weight)
+
+ out = switch * out_s + (1 - switch) * out_l
+ self.padding = ori_p
+ self.dilation = ori_d
+ # post-context
+ avg_x = F.adaptive_avg_pool2d(out, output_size=1)
+ avg_x = self.post_context(avg_x)
+ avg_x = avg_x.expand_as(out)
+ out = out + avg_x
+ return out
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/scatter_points.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/scatter_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/scatter_points.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext',
+ ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
+
+
+class _DynamicScatter(Function):
+
+ @staticmethod
+ def forward(ctx, feats, coors, reduce_type='max'):
+ """convert kitti points(N, >=3) to voxels.
+
+ Args:
+ feats (torch.Tensor): [N, C]. Points features to be reduced
+ into voxels.
+ coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
+ (specifically multi-dim voxel index) of each points.
+ reduce_type (str, optional): Reduce op. support 'max', 'sum' and
+ 'mean'. Default: 'max'.
+
+ Returns:
+ voxel_feats (torch.Tensor): [M, C]. Reduced features, input
+ features that shares the same voxel coordinates are reduced to
+ one row.
+ voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
+ """
+ results = ext_module.dynamic_point_to_voxel_forward(
+ feats, coors, reduce_type)
+ (voxel_feats, voxel_coors, point2voxel_map,
+ voxel_points_count) = results
+ ctx.reduce_type = reduce_type
+ ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
+ voxel_points_count)
+ ctx.mark_non_differentiable(voxel_coors)
+ return voxel_feats, voxel_coors
+
+ @staticmethod
+ def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+ (feats, voxel_feats, point2voxel_map,
+ voxel_points_count) = ctx.saved_tensors
+ grad_feats = torch.zeros_like(feats)
+ # TODO: whether to use index put or use cuda_backward
+ # To use index put, need point to voxel index
+ ext_module.dynamic_point_to_voxel_backward(
+ grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
+ point2voxel_map, voxel_points_count, ctx.reduce_type)
+ return grad_feats, None, None
+
+
+dynamic_scatter = _DynamicScatter.apply
+
+
+class DynamicScatter(nn.Module):
+ """Scatters points into voxels, used in the voxel encoder with dynamic
+ voxelization.
+
+ Note:
+ The CPU and GPU implementation get the same output, but have numerical
+ difference after summation and division (e.g., 5e-7).
+
+ Args:
+ voxel_size (list): list [x, y, z] size of three dimension.
+ point_cloud_range (list): The coordinate range of points, [x_min,
+ y_min, z_min, x_max, y_max, z_max].
+ average_points (bool): whether to use avg pooling to scatter points
+ into voxel.
+ """
+
+ def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+ super().__init__()
+
+ self.voxel_size = voxel_size
+ self.point_cloud_range = point_cloud_range
+ self.average_points = average_points
+
+ def forward_single(self, points, coors):
+ """Scatters points into voxels.
+
+ Args:
+ points (torch.Tensor): Points to be reduced into voxels.
+ coors (torch.Tensor): Corresponding voxel coordinates (specifically
+ multi-dim voxel index) of each points.
+
+ Returns:
+ voxel_feats (torch.Tensor): Reduced features, input features that
+ shares the same voxel coordinates are reduced to one row.
+ voxel_coors (torch.Tensor): Voxel coordinates.
+ """
+ reduce = 'mean' if self.average_points else 'max'
+ return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
+
+ def forward(self, points, coors):
+ """Scatters points/features into voxels.
+
+ Args:
+ points (torch.Tensor): Points to be reduced into voxels.
+ coors (torch.Tensor): Corresponding voxel coordinates (specifically
+ multi-dim voxel index) of each points.
+
+ Returns:
+ voxel_feats (torch.Tensor): Reduced features, input features that
+ shares the same voxel coordinates are reduced to one row.
+ voxel_coors (torch.Tensor): Voxel coordinates.
+ """
+ if coors.size(-1) == 3:
+ return self.forward_single(points, coors)
+ else:
+ batch_size = coors[-1, 0] + 1
+ voxels, voxel_coors = [], []
+ for i in range(batch_size):
+ inds = torch.where(coors[:, 0] == i)
+ voxel, voxel_coor = self.forward_single(
+ points[inds], coors[inds][:, 1:])
+ coor_pad = nn.functional.pad(
+ voxel_coor, (1, 0), mode='constant', value=i)
+ voxel_coors.append(coor_pad)
+ voxels.append(voxel)
+ features = torch.cat(voxels, dim=0)
+ feature_coors = torch.cat(voxel_coors, dim=0)
+
+ return features, feature_coors
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += 'voxel_size=' + str(self.voxel_size)
+ s += ', point_cloud_range=' + str(self.point_cloud_range)
+ s += ', average_points=' + str(self.average_points)
+ s += ')'
+ return s
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/sync_bn.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..46db9200f9eafbad662a04e71f60a099a3178346
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/sync_bn.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.module import Module
+from torch.nn.parameter import Parameter
+
+from annotator.mmpkg.mmcv.cnn import NORM_LAYERS
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+ 'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
+ 'sync_bn_backward_param', 'sync_bn_backward_data'
+])
+
+
+class SyncBatchNormFunction(Function):
+
+ @staticmethod
+ def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ return g.op(
+ 'mmcv::MMCVSyncBatchNorm',
+ input,
+ running_mean,
+ running_var,
+ weight,
+ bias,
+ momentum_f=momentum,
+ eps_f=eps,
+ group_i=group,
+ group_size_i=group_size,
+ stats_mode=stats_mode)
+
+ @staticmethod
+ def forward(self, input, running_mean, running_var, weight, bias, momentum,
+ eps, group, group_size, stats_mode):
+ self.momentum = momentum
+ self.eps = eps
+ self.group = group
+ self.group_size = group_size
+ self.stats_mode = stats_mode
+
+ assert isinstance(
+ input, (torch.HalfTensor, torch.FloatTensor,
+ torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+ f'only support Half or Float Tensor, but {input.type()}'
+ output = torch.zeros_like(input)
+ input3d = input.flatten(start_dim=2)
+ output3d = output.view_as(input3d)
+ num_channels = input3d.size(1)
+
+ # ensure mean/var/norm/std are initialized as zeros
+ # ``torch.empty()`` does not guarantee that
+ mean = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ var = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+ norm = torch.zeros_like(
+ input3d, dtype=torch.float, device=input3d.device)
+ std = torch.zeros(
+ num_channels, dtype=torch.float, device=input3d.device)
+
+ batch_size = input3d.size(0)
+ if batch_size > 0:
+ ext_module.sync_bn_forward_mean(input3d, mean)
+ batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
+ else:
+ # skip updating mean and leave it as zeros when the input is empty
+ batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)
+
+ # synchronize mean and the batch flag
+ vec = torch.cat([mean, batch_flag])
+ if self.stats_mode == 'N':
+ vec *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(vec, group=self.group)
+ total_batch = vec[-1].detach()
+ mean = vec[:num_channels]
+
+ if self.stats_mode == 'default':
+ mean = mean / self.group_size
+ elif self.stats_mode == 'N':
+ mean = mean / total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+
+ # leave var as zeros when the input is empty
+ if batch_size > 0:
+ ext_module.sync_bn_forward_var(input3d, mean, var)
+
+ if self.stats_mode == 'N':
+ var *= batch_size
+ if self.group_size > 1:
+ dist.all_reduce(var, group=self.group)
+
+ if self.stats_mode == 'default':
+ var /= self.group_size
+ elif self.stats_mode == 'N':
+ var /= total_batch.clamp(min=1)
+ else:
+ raise NotImplementedError
+
+ # if the total batch size over all the ranks is zero,
+ # we should not update the statistics in the current batch
+ update_flag = total_batch.clamp(max=1)
+ momentum = update_flag * self.momentum
+ ext_module.sync_bn_forward_output(
+ input3d,
+ mean,
+ var,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ norm,
+ std,
+ output3d,
+ eps=self.eps,
+ momentum=momentum,
+ group_size=self.group_size)
+ self.save_for_backward(norm, std, weight)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(self, grad_output):
+ norm, std, weight = self.saved_tensors
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(weight)
+ grad_input = torch.zeros_like(grad_output)
+ grad_output3d = grad_output.flatten(start_dim=2)
+ grad_input3d = grad_input.view_as(grad_output3d)
+
+ batch_size = grad_input3d.size(0)
+ if batch_size > 0:
+ ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
+ grad_bias)
+
+ # all reduce
+ if self.group_size > 1:
+ dist.all_reduce(grad_weight, group=self.group)
+ dist.all_reduce(grad_bias, group=self.group)
+ grad_weight /= self.group_size
+ grad_bias /= self.group_size
+
+ if batch_size > 0:
+ ext_module.sync_bn_backward_data(grad_output3d, weight,
+ grad_weight, grad_bias, norm, std,
+ grad_input3d)
+
+ return grad_input, None, None, grad_weight, grad_bias, \
+ None, None, None, None, None
+
+
+@NORM_LAYERS.register_module(name='MMSyncBN')
+class SyncBatchNorm(Module):
+ """Synchronized Batch Normalization.
+
+ Args:
+ num_features (int): number of features/chennels in input tensor
+ eps (float, optional): a value added to the denominator for numerical
+ stability. Defaults to 1e-5.
+ momentum (float, optional): the value used for the running_mean and
+ running_var computation. Defaults to 0.1.
+ affine (bool, optional): whether to use learnable affine parameters.
+ Defaults to True.
+ track_running_stats (bool, optional): whether to track the running
+ mean and variance during training. When set to False, this
+ module does not track such statistics, and initializes statistics
+ buffers ``running_mean`` and ``running_var`` as ``None``. When
+ these buffers are ``None``, this module always uses batch
+ statistics in both training and eval modes. Defaults to True.
+ group (int, optional): synchronization of stats happen within
+ each process group individually. By default it is synchronization
+ across the whole world. Defaults to None.
+ stats_mode (str, optional): The statistical mode. Available options
+ includes ``'default'`` and ``'N'``. Defaults to 'default'.
+ When ``stats_mode=='default'``, it computes the overall statistics
+ using those from each worker with equal weight, i.e., the
+ statistics are synchronized and simply divied by ``group``. This
+ mode will produce inaccurate statistics when empty tensors occur.
+ When ``stats_mode=='N'``, it compute the overall statistics using
+ the total number of batches in each worker ignoring the number of
+ group, i.e., the statistics are synchronized and then divied by
+ the total batch ``N``. This mode is beneficial when empty tensors
+ occur during training, as it average the total mean by the real
+ number of batch.
+ """
+
+ def __init__(self,
+ num_features,
+ eps=1e-5,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=True,
+ group=None,
+ stats_mode='default'):
+ super(SyncBatchNorm, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ self.affine = affine
+ self.track_running_stats = track_running_stats
+ group = dist.group.WORLD if group is None else group
+ self.group = group
+ self.group_size = dist.get_world_size(group)
+ assert stats_mode in ['default', 'N'], \
+ f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
+ self.stats_mode = stats_mode
+ if self.affine:
+ self.weight = Parameter(torch.Tensor(num_features))
+ self.bias = Parameter(torch.Tensor(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ if self.track_running_stats:
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.register_buffer('num_batches_tracked',
+ torch.tensor(0, dtype=torch.long))
+ else:
+ self.register_buffer('running_mean', None)
+ self.register_buffer('running_var', None)
+ self.register_buffer('num_batches_tracked', None)
+ self.reset_parameters()
+
+ def reset_running_stats(self):
+ if self.track_running_stats:
+ self.running_mean.zero_()
+ self.running_var.fill_(1)
+ self.num_batches_tracked.zero_()
+
+ def reset_parameters(self):
+ self.reset_running_stats()
+ if self.affine:
+ self.weight.data.uniform_() # pytorch use ones_()
+ self.bias.data.zero_()
+
+ def forward(self, input):
+ if input.dim() < 2:
+ raise ValueError(
+ f'expected at least 2D input, got {input.dim()}D input')
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
+
+ if self.training and self.track_running_stats:
+ if self.num_batches_tracked is not None:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(
+ self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ if self.training or not self.track_running_stats:
+ return SyncBatchNormFunction.apply(
+ input, self.running_mean, self.running_var, self.weight,
+ self.bias, exponential_average_factor, self.eps, self.group,
+ self.group_size, self.stats_mode)
+ else:
+ return F.batch_norm(input, self.running_mean, self.running_var,
+ self.weight, self.bias, False,
+ exponential_average_factor, self.eps)
+
+ def __repr__(self):
+ s = self.__class__.__name__
+ s += f'({self.num_features}, '
+ s += f'eps={self.eps}, '
+ s += f'momentum={self.momentum}, '
+ s += f'affine={self.affine}, '
+ s += f'track_running_stats={self.track_running_stats}, '
+ s += f'group_size={self.group_size},'
+ s += f'stats_mode={self.stats_mode})'
+ return s
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/three_interpolate.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/three_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/three_interpolate.py
@@ -0,0 +1,68 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
+
+
+class ThreeInterpolate(Function):
+ """Performs weighted linear interpolation on 3 features.
+
+ Please refer to `Paper of PointNet++ `_
+ for more details.
+ """
+
+ @staticmethod
+ def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
+ weight: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ features (Tensor): (B, C, M) Features descriptors to be
+ interpolated
+ indices (Tensor): (B, n, 3) index three nearest neighbors
+ of the target features in features
+ weight (Tensor): (B, n, 3) weights of interpolation
+
+ Returns:
+ Tensor: (B, C, N) tensor of the interpolated features
+ """
+ assert features.is_contiguous()
+ assert indices.is_contiguous()
+ assert weight.is_contiguous()
+
+ B, c, m = features.size()
+ n = indices.size(1)
+ ctx.three_interpolate_for_backward = (indices, weight, m)
+ output = torch.cuda.FloatTensor(B, c, n)
+
+ ext_module.three_interpolate_forward(
+ features, indices, weight, output, b=B, c=c, m=m, n=n)
+ return output
+
+ @staticmethod
+ def backward(
+ ctx, grad_out: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ grad_out (Tensor): (B, C, N) tensor with gradients of outputs
+
+ Returns:
+ Tensor: (B, C, M) tensor with gradients of features
+ """
+ idx, weight, m = ctx.three_interpolate_for_backward
+ B, c, n = grad_out.size()
+
+ grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
+ grad_out_data = grad_out.data.contiguous()
+
+ ext_module.three_interpolate_backward(
+ grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
+ return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/three_nn.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/three_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/three_nn.py
@@ -0,0 +1,51 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
+
+
+class ThreeNN(Function):
+ """Find the top-3 nearest neighbors of the target set from the source set.
+
+ Please refer to `Paper of PointNet++ `_
+ for more details.
+ """
+
+ @staticmethod
+ def forward(ctx, target: torch.Tensor,
+ source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ target (Tensor): shape (B, N, 3), points set that needs to
+ find the nearest neighbors.
+ source (Tensor): shape (B, M, 3), points set that is used
+ to find the nearest neighbors of points in target set.
+
+ Returns:
+ Tensor: shape (B, N, 3), L2 distance of each point in target
+ set to their corresponding nearest neighbors.
+ """
+ target = target.contiguous()
+ source = source.contiguous()
+
+ B, N, _ = target.size()
+ m = source.size(1)
+ dist2 = torch.cuda.FloatTensor(B, N, 3)
+ idx = torch.cuda.IntTensor(B, N, 3)
+
+ ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
+ if torch.__version__ != 'parrots':
+ ctx.mark_non_differentiable(idx)
+
+ return torch.sqrt(dist2), idx
+
+ @staticmethod
+ def backward(ctx, a=None, b=None):
+ return None, None
+
+
+three_nn = ThreeNN.apply
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/tin_shift.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/tin_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/tin_shift.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code reference from "Temporal Interlacing Network"
+# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
+# Hao Shao, Shengju Qian, Yu Liu
+# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+ ['tin_shift_forward', 'tin_shift_backward'])
+
+
+class TINShiftFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, shift):
+ C = input.size(2)
+ num_segments = shift.size(1)
+ if C // num_segments <= 0 or C % num_segments != 0:
+ raise ValueError('C should be a multiple of num_segments, '
+ f'but got C={C} and num_segments={num_segments}.')
+
+ ctx.save_for_backward(shift)
+
+ out = torch.zeros_like(input)
+ ext_module.tin_shift_forward(input, shift, out)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ shift = ctx.saved_tensors[0]
+ data_grad_input = grad_output.new(*grad_output.size()).zero_()
+ shift_grad_input = shift.new(*shift.size()).zero_()
+ ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
+
+ return data_grad_input, shift_grad_input
+
+
+tin_shift = TINShiftFunction.apply
+
+
+class TINShift(nn.Module):
+ """Temporal Interlace Shift.
+
+ Temporal Interlace shift is a differentiable temporal-wise frame shifting
+ which is proposed in "Temporal Interlacing Network"
+
+ Please refer to https://arxiv.org/abs/2001.06499 for more details.
+ Code is modified from https://github.com/mit-han-lab/temporal-shift-module
+ """
+
+ def forward(self, input, shift):
+ """Perform temporal interlace shift.
+
+ Args:
+ input (Tensor): Feature map with shape [N, num_segments, C, H * W].
+ shift (Tensor): Shift tensor with shape [N, num_segments].
+
+ Returns:
+ Feature map after temporal interlace shift.
+ """
+ return tin_shift(input, shift)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/upfirdn2d.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..751db20a344e1164748d8d4d8b2f775247925eab
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/upfirdn2d.py
@@ -0,0 +1,330 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+# 2.1 Copyright Grant. Subject to the terms and conditions of this
+# License, each Licensor grants to you a perpetual, worldwide,
+# non-exclusive, royalty-free, copyright license to reproduce,
+# prepare derivative works of, publicly display, publicly perform,
+# sublicense and distribute its Work and any resulting derivative
+# works in any form.
+
+# 3. Limitations
+
+# 3.1 Redistribution. You may reproduce or distribute the Work only
+# if (a) you do so under this License, (b) you include a complete
+# copy of this License with your distribution, and (c) you retain
+# without modification any copyright, patent, trademark, or
+# attribution notices that are present in the Work.
+
+# 3.2 Derivative Works. You may specify that additional or different
+# terms apply to the use, reproduction, and distribution of your
+# derivative works of the Work ("Your Terms") only if (a) Your Terms
+# provide that the use limitation in Section 3.3 applies to your
+# derivative works, and (b) you identify the specific derivative
+# works that are subject to Your Terms. Notwithstanding Your Terms,
+# this License (including the redistribution requirements in Section
+# 3.1) will continue to apply to the Work itself.
+
+# 3.3 Use Limitation. The Work and any derivative works thereof only
+# may be used or intended for use non-commercially. Notwithstanding
+# the foregoing, NVIDIA and its affiliates may use the Work and any
+# derivative works commercially. As used herein, "non-commercially"
+# means for research or evaluation purposes only.
+
+# 3.4 Patent Claims. If you bring or threaten to bring a patent claim
+# against any Licensor (including any claim, cross-claim or
+# counterclaim in a lawsuit) to enforce any patents that you allege
+# are infringed by any Work, then your rights under this License from
+# such Licensor (including the grant in Section 2.1) will terminate
+# immediately.
+
+# 3.5 Trademarks. This License does not grant any rights to use any
+# Licensor’s or its affiliates’ names, logos, or trademarks, except
+# as necessary to reproduce the notices described in this License.
+
+# 3.6 Termination. If you violate any term of this License, then your
+# rights under this License (including the grant in Section 2.1) will
+# terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from annotator.mmpkg.mmcv.utils import to_2tuple
+from ..utils import ext_loader
+
+upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
+ in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ up_x=down_x,
+ up_y=down_y,
+ down_x=up_x,
+ down_y=up_y,
+ pad_x0=g_pad_x0,
+ pad_x1=g_pad_x1,
+ pad_y0=g_pad_y0,
+ pad_y1=g_pad_y1)
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
+ in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
+ ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ up_x=ctx.up_x,
+ up_y=ctx.up_y,
+ down_x=ctx.down_x,
+ down_y=ctx.down_y,
+ pad_x0=ctx.pad_x0,
+ pad_x1=ctx.pad_x1,
+ pad_y0=ctx.pad_y0,
+ pad_y1=ctx.pad_y1)
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
+ ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+    """Autograd function wrapping the ``upfirdn2d`` extension op.
+
+    ``forward`` calls the op loaded through ``ext_loader``; ``backward``
+    delegates to :class:`UpFirDn2dBackward`, which makes the gradient itself
+    differentiable.
+    """
+
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        """Run the op on ``input``.
+
+        Args:
+            ctx: Autograd context; stores the kernels and the geometry used
+                by :meth:`backward`.
+            input (Tensor): 4-D tensor; its shape is unpacked as
+                ``(batch, channel, in_h, in_w)``.
+            kernel (Tensor): 2-D tensor; its shape is unpacked as
+                ``(kernel_h, kernel_w)``.
+            up (tuple[int]): ``(up_x, up_y)``.
+            down (tuple[int]): ``(down_x, down_y)``.
+            pad (tuple[int]): ``(pad_x0, pad_x1, pad_y0, pad_y1)``.
+
+        Returns:
+            Tensor: Result viewed as ``(-1, channel, out_h, out_w)``.
+        """
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        kernel_h, kernel_w = kernel.shape
+        batch, channel, in_h, in_w = input.shape
+        ctx.in_size = input.shape
+
+        # Collapse the leading dimensions and add a trailing dimension of 1.
+        input = input.reshape(-1, in_h, in_w, 1)
+
+        # The second saved tensor is the kernel flipped along both axes; it
+        # is passed to ``UpFirDn2dBackward`` as ``grad_kernel``.
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        # Padding handed to ``UpFirDn2dBackward`` as ``g_pad``.
+        g_pad_x0 = kernel_w - pad_x0 - 1
+        g_pad_y0 = kernel_h - pad_y0 - 1
+        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+        out = upfirdn2d_ext.upfirdn2d(
+            input,
+            kernel,
+            up_x=up_x,
+            up_y=up_y,
+            down_x=down_x,
+            down_y=down_y,
+            pad_x0=pad_x0,
+            pad_x1=pad_x1,
+            pad_y0=pad_y0,
+            pad_y1=pad_y1)
+        # out = out.view(major, out_h, out_w, minor)
+        out = out.view(-1, channel, out_h, out_w)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Compute the gradient with respect to ``input``.
+
+        Args:
+            ctx: Autograd context populated by :meth:`forward`.
+            grad_output (Tensor): Gradient with respect to the output.
+
+        Returns:
+            tuple: ``grad_input`` followed by four ``None`` values for
+            ``kernel``, ``up``, ``down`` and ``pad``.
+        """
+        kernel, grad_kernel = ctx.saved_tensors
+
+        grad_input = UpFirDn2dBackward.apply(
+            grad_output,
+            kernel,
+            grad_kernel,
+            ctx.up,
+            ctx.down,
+            ctx.pad,
+            ctx.g_pad,
+            ctx.in_size,
+            ctx.out_size,
+        )
+
+        return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    """UpFIRDn for 2d features.
+
+    UpFIRDn is short for upsample, apply FIR filter and downsample. More
+    details can be found in:
+    https://www.mathworks.com/help/signal/ref/upfirdn.html
+
+    CPU tensors are handled by :func:`upfirdn2d_native`; tensors on any other
+    device go through the :class:`UpFirDn2d` autograd function.
+
+    Args:
+        input (Tensor): Tensor with shape of (n, c, h, w).
+        kernel (Tensor): Filter kernel.
+        up (int | tuple[int], optional): Upsampling factor. If given a number,
+            we will use this factor for the both height and width side.
+            A tuple is read as ``(up_x, up_y)``. Defaults to 1.
+        down (int | tuple[int], optional): Downsampling factor. If given a
+            number, we will use this factor for the both height and width
+            side. A tuple is read as ``(down_x, down_y)``. Defaults to 1.
+        pad (tuple[int], optional): Padding for tensors, either
+            ``(pad_0, pad_1)``, which is applied to both the x and the y
+            axis, or ``(x_pad_0, x_pad_1, y_pad_0, y_pad_1)``.
+            Defaults to (0, 0).
+
+    Returns:
+        Tensor: Tensor after UpFIRDn.
+
+    Raises:
+        ValueError: If ``pad`` has neither 2 nor 4 elements.
+    """
+    # Normalise ``pad`` once for both code paths. Previously an unsupported
+    # length surfaced as an UnboundLocalError on the non-CPU path.
+    if len(pad) == 2:
+        pad = (pad[0], pad[1], pad[0], pad[1])
+    elif len(pad) != 4:
+        raise ValueError(
+            f'pad must have 2 or 4 elements, but got {len(pad)}')
+
+    up = to_2tuple(up)
+    down = to_2tuple(down)
+
+    if input.device.type == 'cpu':
+        return upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
+                                pad[0], pad[1], pad[2], pad[3])
+
+    return UpFirDn2d.apply(input, kernel, up, down, pad)
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+                     pad_y0, pad_y1):
+    """UpFIRDn implemented with plain PyTorch ops.
+
+    Args:
+        input (Tensor): 4-D tensor; its shape is unpacked as
+            ``(_, channel, in_h, in_w)``.
+        kernel (Tensor): 2-D filter kernel.
+        up_x (int): Upsampling factor along the width.
+        up_y (int): Upsampling factor along the height.
+        down_x (int): Downsampling stride along the width.
+        down_y (int): Downsampling stride along the height.
+        pad_x0 (int): Padding before the width axis; negative values crop.
+        pad_x1 (int): Padding after the width axis; negative values crop.
+        pad_y0 (int): Padding before the height axis; negative values crop.
+        pad_y1 (int): Padding after the height axis; negative values crop.
+
+    Returns:
+        Tensor: Result viewed as ``(-1, channel, out_h, out_w)``.
+    """
+    _, channel, in_h, in_w = input.shape
+    flat = input.reshape(-1, in_h, in_w, 1)
+    minor = flat.shape[3]
+    kernel_h, kernel_w = kernel.shape
+
+    up_h = in_h * up_y
+    up_w = in_w * up_x
+    padded_h = up_h + pad_y0 + pad_y1
+    padded_w = up_w + pad_x0 + pad_x1
+
+    # Upsample: give every sample its own unit-sized axes, right-pad those
+    # axes with zeros up to the factor, then merge them back.
+    out = flat.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, up_h, up_w, minor)
+
+    # Positive paddings add zeros, negative paddings are applied as a crop.
+    grow = [
+        0, 0,
+        max(pad_x0, 0),
+        max(pad_x1, 0),
+        max(pad_y0, 0),
+        max(pad_y1, 0)
+    ]
+    out = F.pad(out, grow)
+    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
+    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
+    out = out[:, crop_top:out.shape[1] - crop_bottom,
+              crop_left:out.shape[2] - crop_right, :]
+
+    # Filter with the kernel flipped along both axes.
+    out = out.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])
+    flipped = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, flipped)
+
+    conv_h = padded_h - kernel_h + 1
+    conv_w = padded_w - kernel_w + 1
+    out = out.reshape(-1, minor, conv_h, conv_w).permute(0, 2, 3, 1)
+
+    # Downsample by striding.
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (padded_h - kernel_h) // down_y + 1
+    out_w = (padded_w - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/voxelize.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/voxelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/ops/voxelize.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+ '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
+
+
+class _Voxelization(Function):
+    """Autograd function dispatching to the voxelization extension ops."""
+
+    @staticmethod
+    def forward(ctx,
+                points,
+                voxel_size,
+                coors_range,
+                max_points=35,
+                max_voxels=20000):
+        """Convert kitti points(N, >=3) to voxels.
+
+        Args:
+            points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
+                and points[:, 3:] contain other information like reflectivity.
+            voxel_size (tuple or float): The size of voxel with the shape of
+                [3].
+            coors_range (tuple or float): The coordinate range of voxel with
+                the shape of [6].
+            max_points (int, optional): maximum points contained in a voxel.
+                ``-1`` selects ``dynamic_voxelize_forward``. Default: 35.
+            max_voxels (int, optional): maximum voxels this function create.
+                ``-1`` also selects ``dynamic_voxelize_forward``. For second,
+                20000 is a good choice. Users should shuffle points before
+                call this function because max_voxels may drop points.
+                Default: 20000.
+
+        Returns:
+            torch.Tensor | tuple[torch.Tensor]: On the dynamic path, an int
+            tensor of coordinates with shape [N, 3]. Otherwise the tuple
+            ``(voxels_out, coors_out, num_points_per_voxel_out)`` with shapes
+            [M, max_points, ndim], [M, 3] and [M], where M is the value
+            returned by ``hard_voxelize_forward``.
+        """
+        use_dynamic = max_points == -1 or max_voxels == -1
+        if use_dynamic:
+            # One coordinate row per input point, filled in by the op.
+            coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
+            ext_module.dynamic_voxelize_forward(points, coors, voxel_size,
+                                                coors_range, 3)
+            return coors
+
+        # Buffers sized for the maximum number of voxels; the op fills them
+        # and reports how many entries are valid.
+        voxels = points.new_zeros(
+            size=(max_voxels, max_points, points.size(1)))
+        coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
+        num_points_per_voxel = points.new_zeros(
+            size=(max_voxels, ), dtype=torch.int)
+        voxel_num = ext_module.hard_voxelize_forward(
+            points, voxels, coors, num_points_per_voxel, voxel_size,
+            coors_range, max_points, max_voxels, 3)
+
+        # Keep only the valid leading entries.
+        return (voxels[:voxel_num], coors[:voxel_num],
+                num_points_per_voxel[:voxel_num])
+
+
+# Functional entry point used by the ``Voxelization`` module.
+voxelization = _Voxelization.apply
+
+
+class Voxelization(nn.Module):
+    """Convert kitti points(N, >=3) to voxels.
+
+    Please refer to `PVCNN <https://arxiv.org/abs/1907.03739>`_ for more
+    details.
+
+    Args:
+        voxel_size (tuple or float): The size of voxel with the shape of [3].
+        point_cloud_range (tuple or float): The coordinate range of voxel with
+            the shape of [6].
+        max_num_points (int): maximum points contained in a voxel. if
+            max_num_points=-1, it means using dynamic_voxelize.
+        max_voxels (int | tuple[int], optional): maximum voxels this function
+            create. A tuple is stored as given and indexed as
+            ``(training, evaluation)``; any other value is expanded with
+            ``_pair``. For second, 20000 is a good choice. Users should
+            shuffle points before call this function because max_voxels may
+            drop points. Default: 20000.
+    """
+
+    def __init__(self,
+                 voxel_size,
+                 point_cloud_range,
+                 max_num_points,
+                 max_voxels=20000):
+        super().__init__()
+
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.max_num_points = max_num_points
+        self.max_voxels = (
+            max_voxels if isinstance(max_voxels, tuple) else _pair(max_voxels))
+
+        range_t = torch.tensor(point_cloud_range, dtype=torch.float32)
+        size_t = torch.tensor(voxel_size, dtype=torch.float32)
+        extent = range_t[3:] - range_t[:3]
+        self.grid_size = torch.round(extent / size_t).long()
+        # the origin shape is as [x-len, y-len, z-len]
+        # [w, h, d] -> [d, h, w]
+        self.pcd_shape = [*self.grid_size[:2], 1][::-1]
+
+    def forward(self, input):
+        """Voxelize ``input`` with the limit matching the current mode."""
+        limit = self.max_voxels[0] if self.training else self.max_voxels[1]
+        return voxelization(input, self.voxel_size, self.point_cloud_range,
+                            self.max_num_points, limit)
+
+    def __repr__(self):
+        fields = [
+            'voxel_size=' + str(self.voxel_size),
+            ', point_cloud_range=' + str(self.point_cloud_range),
+            ', max_num_points=' + str(self.max_num_points),
+            ', max_voxels=' + str(self.max_voxels),
+        ]
+        return self.__class__.__name__ + '(' + ''.join(fields) + ')'
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .collate import collate
+from .data_container import DataContainer
+from .data_parallel import MMDataParallel
+from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import is_module_wrapper
+
+__all__ = [
+ 'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
+ 'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/_functions.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/_functions.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import _get_stream
+
+
+def scatter(input, devices, streams=None):
+    """Scatters tensor across multiple GPUs.
+
+    Args:
+        input (list | torch.Tensor): A tensor, or an arbitrarily nested list
+            of tensors.
+        devices (list[int]): Target device ids. ``[-1]`` keeps tensors where
+            they are and only adds a leading dimension.
+        streams (list, optional): One stream (or ``None``) per device.
+            Defaults to ``None`` for every device.
+
+    Returns:
+        list | torch.Tensor: Same nesting as ``input``.
+
+    Raises:
+        Exception: If ``input`` is neither a list nor a tensor.
+    """
+    if streams is None:
+        streams = [None] * len(devices)
+
+    if isinstance(input, torch.Tensor):
+        output = input.contiguous()
+        # TODO: copy to a pinned buffer first (if copying from CPU)
+        stream = streams[0] if output.numel() > 0 else None
+        if devices == [-1]:
+            # unsqueeze the first dimension thus the tensor's shape is the
+            # same as those scattered with GPU.
+            return output.unsqueeze(0)
+        with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+            output = output.cuda(devices[0], non_blocking=True)
+        return output
+
+    if isinstance(input, list):
+        # Consecutive runs of ``chunk_size`` items share one device.
+        chunk_size = (len(input) - 1) // len(devices) + 1
+        outputs = []
+        for index, item in enumerate(input):
+            slot = index // chunk_size
+            outputs.append(scatter(item, [devices[slot]], [streams[slot]]))
+        return outputs
+
+    raise Exception(f'Unknown type {type(input)}.')
+
+
+def synchronize_stream(output, devices, streams):
+    """Make the current stream of each device wait for its copy stream.
+
+    Args:
+        output (list | torch.Tensor): Result of :func:`scatter`; a tensor or
+            an arbitrarily nested list of tensors.
+        devices (list[int]): Device ids, parallel to ``streams``.
+        streams (list): Streams that were used for the copies.
+
+    Raises:
+        Exception: If ``output`` is neither a list nor a tensor.
+    """
+    if isinstance(output, list):
+        # Use the same item -> device mapping as ``scatter`` (ceil division).
+        # Floor division paired items with the wrong device and skipped the
+        # trailing items whenever the list length was not a multiple of the
+        # number of devices.
+        chunk_size = (len(output) - 1) // len(devices) + 1
+        for index, item in enumerate(output):
+            slot = index // chunk_size
+            synchronize_stream(item, [devices[slot]], [streams[slot]])
+    elif isinstance(output, torch.Tensor):
+        # Empty tensors are skipped.
+        if output.numel() != 0:
+            with torch.cuda.device(devices[0]):
+                main_stream = torch.cuda.current_stream()
+                main_stream.wait_stream(streams[0])
+                output.record_stream(main_stream)
+    else:
+        raise Exception(f'Unknown type {type(output)}.')
+
+
+def get_input_device(input):
+    """Return the CUDA device index of the first CUDA tensor in ``input``.
+
+    Args:
+        input (list | torch.Tensor): A tensor, or an arbitrarily nested list
+            of tensors.
+
+    Returns:
+        int: ``input.get_device()`` of the first CUDA tensor found, or ``-1``
+        when there is none.
+
+    Raises:
+        Exception: If an element is neither a list nor a tensor.
+    """
+    if isinstance(input, torch.Tensor):
+        return input.get_device() if input.is_cuda else -1
+
+    if not isinstance(input, list):
+        raise Exception(f'Unknown type {type(input)}.')
+
+    for item in input:
+        device = get_input_device(item)
+        if device != -1:
+            return device
+    return -1
+
+
+class Scatter:
+    """Namespace for the scatter entry point used with ``DataContainer``."""
+
+    @staticmethod
+    def forward(target_gpus, input):
+        """Scatter ``input`` to ``target_gpus``.
+
+        Args:
+            target_gpus (list[int]): Target device ids; ``[-1]`` means that
+                nothing is moved to a GPU.
+            input (list | torch.Tensor): Data accepted by :func:`scatter`.
+
+        Returns:
+            tuple: ``tuple(outputs)`` of the scattered data.
+        """
+        needs_copy = get_input_device(input) == -1 and target_gpus != [-1]
+        # Perform CPU to GPU copies in a background stream
+        streams = ([_get_stream(device) for device in target_gpus]
+                   if needs_copy else None)
+
+        outputs = scatter(input, target_gpus, streams)
+        # Synchronize with the copy stream
+        if streams is not None:
+            synchronize_stream(outputs, target_gpus, streams)
+
+        return tuple(outputs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/collate.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/collate.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+
+from .data_container import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+    """Puts each data field into a tensor/DataContainer with outer dimension
+    batch size.
+
+    Extend default_collate to add support for
+    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
+
+    1. cpu_only = True, e.g., meta data
+    2. cpu_only = False, stack = True, e.g., images tensors
+    3. cpu_only = False, stack = False, e.g., gt bboxes
+
+    Args:
+        batch (Sequence): Samples to collate. Elements may be
+            ``DataContainer`` objects, sequences, mappings, or anything that
+            ``default_collate`` accepts.
+        samples_per_gpu (int): Number of consecutive samples that form one
+            group of the result. Defaults to 1.
+
+    Returns:
+        DataContainer | list | dict | object: A ``DataContainer`` holding one
+        entry per group when the elements are ``DataContainer`` objects; a
+        list or dict collated element-wise for sequences and mappings;
+        otherwise the result of ``default_collate``.
+
+    Raises:
+        TypeError: If ``batch`` is not a sequence.
+    """
+
+    if not isinstance(batch, Sequence):
+        # Not every object has ``dtype``; fall back to the type so that the
+        # intended TypeError is raised instead of an AttributeError.
+        raise TypeError(
+            f'{getattr(batch, "dtype", type(batch))} is not supported.')
+
+    if isinstance(batch[0], DataContainer):
+        stacked = []
+        if batch[0].cpu_only:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+            return DataContainer(
+                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
+        elif batch[0].stack:
+            for i in range(0, len(batch), samples_per_gpu):
+                assert isinstance(batch[i].data, torch.Tensor)
+                group = batch[i:i + samples_per_gpu]
+                pad_dims = batch[i].pad_dims
+
+                if pad_dims is not None:
+                    ndim = batch[i].dim()
+                    assert ndim > pad_dims
+                    # max_shape[k] is the largest size of dimension -(k + 1)
+                    # within the group.
+                    max_shape = [
+                        batch[i].size(-dim) for dim in range(1, pad_dims + 1)
+                    ]
+                    for sample in group:
+                        # Leading (non-padded) dimensions must agree.
+                        for dim in range(0, ndim - pad_dims):
+                            assert batch[i].size(dim) == sample.size(dim)
+                        for dim in range(1, pad_dims + 1):
+                            max_shape[dim - 1] = max(max_shape[dim - 1],
+                                                     sample.size(-dim))
+                    padded_samples = []
+                    for sample in group:
+                        # ``F.pad`` takes (before, after) pairs starting at
+                        # the last dimension; only "after" is used.
+                        pad = [0 for _ in range(pad_dims * 2)]
+                        for dim in range(1, pad_dims + 1):
+                            pad[2 * dim -
+                                1] = max_shape[dim - 1] - sample.size(-dim)
+                        padded_samples.append(
+                            F.pad(
+                                sample.data, pad, value=sample.padding_value))
+                    stacked.append(default_collate(padded_samples))
+                else:
+                    stacked.append(
+                        default_collate([sample.data for sample in group]))
+        else:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+    elif (isinstance(batch[0], Sequence)
+          and not isinstance(batch[0], (str, bytes))):
+        # Strings are sequences of strings; transposing them would recurse
+        # without end, so they are left to ``default_collate`` below.
+        transposed = zip(*batch)
+        return [collate(samples, samples_per_gpu) for samples in transposed]
+    elif isinstance(batch[0], Mapping):
+        return {
+            key: collate([d[key] for d in batch], samples_per_gpu)
+            for key in batch[0]
+        }
+    else:
+        return default_collate(batch)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def assert_tensor_type(func):
+    """Decorator that only lets ``func`` run when ``args[0].data`` is a tensor.
+
+    Args:
+        func (callable): Method whose first positional argument exposes
+            ``data`` and ``datatype`` attributes.
+
+    Returns:
+        callable: Wrapper that raises ``AttributeError`` when the first
+        argument's ``data`` is not a ``torch.Tensor`` and calls ``func``
+        otherwise.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        owner = args[0]
+        if isinstance(owner.data, torch.Tensor):
+            return func(*args, **kwargs)
+        raise AttributeError(
+            f'{owner.__class__.__name__} has no attribute '
+            f'{func.__name__} for type {owner.datatype}')
+
+    return wrapper
+
+
+class DataContainer:
+    """A container for any type of objects.
+
+    Typically tensors will be stacked in the collate function and sliced along
+    some dimension in the scatter function. This behavior has some limitations.
+    1. All tensors have to be the same size.
+    2. Types are limited (numpy array or Tensor).
+
+    We design `DataContainer` and `MMDataParallel` to overcome these
+    limitations. The behavior can be either of the following.
+
+    - copy to GPU, pad all tensors to the same size and stack them
+    - copy to GPU without stacking
+    - leave the objects as is and pass it to the model
+    - pad_dims specifies the number of last few dimensions to do padding
+
+    Args:
+        data: The wrapped object.
+        stack (bool): Value exposed through :attr:`stack`. Defaults to False.
+        padding_value: Value exposed through :attr:`padding_value`.
+            Defaults to 0.
+        cpu_only (bool): Value exposed through :attr:`cpu_only`.
+            Defaults to False.
+        pad_dims (int | None): One of ``None``, 1, 2 or 3. Defaults to 2.
+    """
+
+    def __init__(self,
+                 data,
+                 stack=False,
+                 padding_value=0,
+                 cpu_only=False,
+                 pad_dims=2):
+        assert pad_dims in (None, 1, 2, 3)
+        self._data = data
+        self._stack = stack
+        self._padding_value = padding_value
+        self._cpu_only = cpu_only
+        self._pad_dims = pad_dims
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({self.data!r})'
+
+    def __len__(self):
+        return len(self._data)
+
+    @property
+    def data(self):
+        """The wrapped object."""
+        return self._data
+
+    @property
+    def datatype(self):
+        """``data.type()`` for tensors, ``type(data)`` for anything else."""
+        payload = self.data
+        return (payload.type()
+                if isinstance(payload, torch.Tensor) else type(payload))
+
+    @property
+    def cpu_only(self):
+        """The ``cpu_only`` flag given at construction."""
+        return self._cpu_only
+
+    @property
+    def stack(self):
+        """The ``stack`` flag given at construction."""
+        return self._stack
+
+    @property
+    def padding_value(self):
+        """The ``padding_value`` given at construction."""
+        return self._padding_value
+
+    @property
+    def pad_dims(self):
+        """The ``pad_dims`` given at construction."""
+        return self._pad_dims
+
+    @assert_tensor_type
+    def size(self, *args, **kwargs):
+        """Forward to ``data.size``; only available for tensor data."""
+        return self.data.size(*args, **kwargs)
+
+    @assert_tensor_type
+    def dim(self):
+        """Forward to ``data.dim``; only available for tensor data."""
+        return self.data.dim()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+
+from torch.nn.parallel import DataParallel
+
+from .scatter_gather import scatter_kwargs
+
+
+class MMDataParallel(DataParallel):
+    """The DataParallel module that supports DataContainer.
+
+    MMDataParallel has two main differences with PyTorch DataParallel:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data during both GPU and CPU inference.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.
+
+    Args:
+        module (:class:`nn.Module`): Module to be encapsulated.
+        device_ids (list[int]): Device IDS of modules to be scattered to.
+            Defaults to None when GPU is not available.
+        output_device (str | int): Device ID for output. Defaults to None.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+    """
+
+    def __init__(self, *args, dim=0, **kwargs):
+        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+        self.dim = dim
+
+    def forward(self, *inputs, **kwargs):
+        """Override the original forward function.
+
+        The main difference lies in the CPU inference where the data in
+        :class:`DataContainers` will still be gathered.
+        """
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module(*inputs[0], **kwargs[0])
+        else:
+            return super().forward(*inputs, **kwargs)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        """Scatter with ``scatter_kwargs`` along ``self.dim``."""
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def _call_module_step(self, step_name, inputs, kwargs):
+        """Shared implementation of ``train_step`` and ``val_step``.
+
+        Args:
+            step_name (str): Name of the method to call on ``self.module``.
+            inputs (tuple): Positional arguments of the step.
+            kwargs (dict): Keyword arguments of the step.
+
+        Returns:
+            The value returned by the wrapped module's step method.
+
+        Raises:
+            AssertionError: If more than one device id is configured.
+            RuntimeError: If a parameter or buffer of the module is not on
+                ``self.src_device_obj``.
+        """
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return getattr(self.module, step_name)(*inputs[0], **kwargs[0])
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return getattr(self.module, step_name)(*inputs[0], **kwargs[0])
+
+    def train_step(self, *inputs, **kwargs):
+        """Scatter the arguments and call ``self.module.train_step``."""
+        return self._call_module_step('train_step', inputs, kwargs)
+
+    def val_step(self, *inputs, **kwargs):
+        """Scatter the arguments and call ``self.module.val_step``."""
+        return self._call_module_step('val_step', inputs, kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..929c7a451a7443d715ab0cceef530c53eff44cb9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+ _find_tensors)
+
+from annotator.mmpkg.mmcv import print_log
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import scatter_kwargs
+
+
+class MMDistributedDataParallel(DistributedDataParallel):
+    """The DDP module that supports DataContainer.
+
+    MMDDP has two main differences with PyTorch DDP:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data.
+    - It implements two APIs ``train_step()`` and ``val_step()``.
+    """
+
+    def to_kwargs(self, inputs, kwargs, device_id):
+        """Scatter ``inputs`` and ``kwargs`` to the single ``device_id``."""
+        # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+        # to move all tensors to device_id
+        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        """Scatter ``inputs`` and ``kwargs`` to ``device_ids``."""
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def train_step(self, *inputs, **kwargs):
+        """train_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.train_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+
+        NOTE(review): the body also contains a branch for PyTorch >= 1.7, so
+        the version range stated above may be outdated - confirm.
+
+        Returns:
+            The output of ``self.module.train_step`` (or of ``self.gather``
+            when more than one device id is configured).
+        """
+
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            print_log(
+                'Reducer buckets have been rebuilt in this iteration.',
+                logger='mmcv')
+
+        # The attribute may be absent, in which case syncing is the default.
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.train_step(*inputs[0], **kwargs[0])
+            else:
+                # NOTE(review): this branch goes through
+                # ``self.parallel_apply`` and does not name ``train_step``
+                # explicitly - confirm that this is intended.
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.train_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
+
+    def val_step(self, *inputs, **kwargs):
+        """val_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.val_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+
+        NOTE(review): the body also contains a branch for PyTorch >= 1.7, so
+        the version range stated above may be outdated - confirm.
+
+        Returns:
+            The output of ``self.module.val_step`` (or of ``self.gather``
+            when more than one device id is configured).
+        """
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            print_log(
+                'Reducer buckets have been rebuilt in this iteration.',
+                logger='mmcv')
+
+        # The attribute may be absent, in which case syncing is the default.
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.val_step(*inputs[0], **kwargs[0])
+            else:
+                # NOTE(review): this branch goes through
+                # ``self.parallel_apply`` and does not name ``val_step``
+                # explicitly - confirm that this is intended.
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.val_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..be60a37041fc6a76deae1851dde30448eaff054f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter_kwargs
+
+
+@MODULE_WRAPPERS.register_module()
+class MMDistributedDataParallel(nn.Module):
+    """Module wrapper that broadcasts the module state on construction.
+
+    Args:
+        module (nn.Module): Module to wrap.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+        broadcast_buffers (bool): Whether the module buffers are broadcast in
+            addition to the state dict. Defaults to True.
+        bucket_cap_mb (int): Bucket size in MiB used when broadcasting.
+            Defaults to 25.
+    """
+
+    def __init__(self,
+                 module,
+                 dim=0,
+                 broadcast_buffers=True,
+                 bucket_cap_mb=25):
+        super().__init__()
+        self.module = module
+        self.dim = dim
+        self.broadcast_buffers = broadcast_buffers
+
+        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+        self._sync_params()
+
+    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+        """Broadcast ``tensors`` from rank 0 in flattened buckets."""
+        for bucket in _take_tensors(tensors, buffer_size):
+            flat = _flatten_dense_tensors(bucket)
+            dist.broadcast(flat, 0)
+            received = _unflatten_dense_tensors(flat, bucket)
+            for tensor, synced in zip(bucket, received):
+                tensor.copy_(synced)
+
+    def _sync_params(self):
+        """Broadcast the state dict and, if enabled, the buffers."""
+        module_states = list(self.module.state_dict().values())
+        if module_states:
+            self._dist_broadcast_coalesced(module_states,
+                                           self.broadcast_bucket_size)
+        if not self.broadcast_buffers:
+            return
+
+        legacy_api = (
+            TORCH_VERSION != 'parrots'
+            and digit_version(TORCH_VERSION) < digit_version('1.0'))
+        source = (
+            self.module._all_buffers() if legacy_api else self.module.buffers())
+        buffers = [b.data for b in source]
+        if buffers:
+            self._dist_broadcast_coalesced(buffers,
+                                           self.broadcast_bucket_size)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        """Scatter with ``scatter_kwargs`` along ``self.dim``."""
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def _scatter_to_current_device(self, inputs, kwargs):
+        """Scatter to the device returned by ``torch.cuda.current_device``."""
+        return self.scatter(inputs, kwargs, [torch.cuda.current_device()])
+
+    def forward(self, *inputs, **kwargs):
+        """Scatter the arguments and call the wrapped module."""
+        inputs, kwargs = self._scatter_to_current_device(inputs, kwargs)
+        return self.module(*inputs[0], **kwargs[0])
+
+    def train_step(self, *inputs, **kwargs):
+        """Scatter the arguments and call ``self.module.train_step``."""
+        inputs, kwargs = self._scatter_to_current_device(inputs, kwargs)
+        return self.module.train_step(*inputs[0], **kwargs[0])
+
+    def val_step(self, *inputs, **kwargs):
+        """Scatter the arguments and call ``self.module.val_step``."""
+        inputs, kwargs = self._scatter_to_current_device(inputs, kwargs)
+        return self.module.val_step(*inputs[0], **kwargs[0])
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/registry.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ce151e5f890691e8b583e5d50b492801bae82bd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/registry.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from annotator.mmpkg.mmcv.utils import Registry
+
+# Registry of the classes that count as module wrappers.
+MODULE_WRAPPERS = Registry('module wrapper')
+# PyTorch's own parallel wrappers are registered up front.
+MODULE_WRAPPERS.register_module(module=DataParallel)
+MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+
+from ._functions import Scatter
+from .data_container import DataContainer
+
+
+def scatter(inputs, target_gpus, dim=0):
+    """Scatter inputs to target gpus.
+
+    The only difference from original :func:`scatter` is to add support for
+    :type:`~mmcv.parallel.DataContainer`.
+
+    Args:
+        inputs: A tensor, a ``DataContainer``, a tuple/list/dict of those, or
+            any other object.
+        target_gpus (list[int]): Target device ids; ``[-1]`` selects the
+            self-implemented CPU path for tensors.
+        dim (int): Dimension along which tensors are scattered. Defaults to 0.
+
+    Returns:
+        The scattered structure; objects of an unhandled type are repeated
+        once per entry of ``target_gpus``.
+    """
+
+    def scatter_map(obj):
+        if isinstance(obj, torch.Tensor):
+            if target_gpus == [-1]:
+                # for CPU inference we use self-implemented scatter
+                return Scatter.forward(target_gpus, obj)
+            return OrigScatter.apply(target_gpus, None, dim, obj)
+
+        if isinstance(obj, DataContainer):
+            if obj.cpu_only:
+                return obj.data
+            return Scatter.forward(target_gpus, obj.data)
+
+        # Non-empty containers are scattered element-wise and transposed so
+        # that each entry of the result is a container of the same kind.
+        if isinstance(obj, tuple) and len(obj) > 0:
+            return list(zip(*map(scatter_map, obj)))
+        if isinstance(obj, list) and len(obj) > 0:
+            return list(map(list, zip(*map(scatter_map, obj))))
+        if isinstance(obj, dict) and len(obj) > 0:
+            per_target = zip(*map(scatter_map, obj.items()))
+            return list(map(type(obj), per_target))
+
+        return [obj for _ in target_gpus]
+
+    # After scatter_map is called, a scatter_map cell will exist. This cell
+    # has a reference to the actual function scatter_map, which has references
+    # to a closure that has a reference to the scatter_map cell (because the
+    # fn is recursive). To avoid this reference cycle, we set the function to
+    # None, clearing the cell
+    try:
+        return scatter_map(inputs)
+    finally:
+        scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+    """Scatter with support for kwargs dictionary.
+
+    Args:
+        inputs: Positional arguments to scatter; skipped when falsy.
+        kwargs: Keyword arguments to scatter; skipped when falsy.
+        target_gpus (list[int]): Target device ids.
+        dim (int): Dimension along which tensors are scattered. Defaults to 0.
+
+    Returns:
+        tuple[tuple, tuple]: Scattered positional and keyword arguments,
+        padded with ``()`` / ``{}`` so that both have the same length.
+    """
+    scattered_args = scatter(inputs, target_gpus, dim) if inputs else []
+    scattered_kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+
+    shortfall = len(scattered_kwargs) - len(scattered_args)
+    if shortfall > 0:
+        scattered_args.extend(() for _ in range(shortfall))
+    elif shortfall < 0:
+        scattered_kwargs.extend({} for _ in range(-shortfall))
+
+    return tuple(scattered_args), tuple(scattered_kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/utils.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/parallel/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import MODULE_WRAPPERS
+
+
+def is_module_wrapper(module):
+ """Check if a module is a module wrapper.
+
+ The following 3 modules in MMCV (and their subclasses) are regarded as
+ module wrappers: DataParallel, DistributedDataParallel,
+ MMDistributedDataParallel (the deprecated version). You may add you own
+ module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: True if the input module is a module wrapper.
+ """
+ module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+ return isinstance(module, module_wrappers)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_module import BaseModule, ModuleList, Sequential
+from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+ _load_checkpoint_with_prefix, load_checkpoint,
+ load_state_dict, save_checkpoint, weights_to_cpu)
+from .default_constructor import DefaultRunnerConstructor
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+ init_dist, master_only)
+from .epoch_based_runner import EpochBasedRunner, Runner
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
+from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
+ DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+ Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+ LoggerHook, LrUpdaterHook, MlflowLoggerHook,
+ NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+ SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+ WandbLoggerHook)
+from .iter_based_runner import IterBasedRunner, IterLoader
+from .log_buffer import LogBuffer
+from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
+ DefaultOptimizerConstructor, build_optimizer,
+ build_optimizer_constructor)
+from .priority import Priority, get_priority
+from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+
+__all__ = [
+ 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+ 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+ 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+ 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+ 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+ 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+ 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+ 'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+ 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+ 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+ 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+ '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+ 'ModuleList', 'GradientCumulativeOptimizerHook',
+ 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/base_module.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..72e1164dfc442056cdc386050177f011b4e9900f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/base_module.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv.runner.dist_utils import master_only
+from annotator.mmpkg.mmcv.utils.logging import get_logger, logger_initialized, print_log
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+ """Base module for all modules in openmmlab.
+
+ ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
+ functionality of parameter initialization. Compared with
+ ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
+
+ - ``init_cfg``: the config to control the initialization.
+ - ``init_weights``: The function of parameter
+ initialization and recording initialization
+ information.
+ - ``_params_init_info``: Used to track the parameter
+ initialization information. This attribute only
+ exists during executing the ``init_weights``.
+
+ Args:
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, init_cfg=None):
+ """Initialize BaseModule, inherited from `torch.nn.Module`"""
+
+ # NOTE init_cfg can be defined in different levels, but init_cfg
+ # in low levels has a higher priority.
+
+ super(BaseModule, self).__init__()
+ # define default value of init_cfg instead of hard code
+ # in init_weights() function
+ self._is_init = False
+
+ self.init_cfg = copy.deepcopy(init_cfg)
+
+ # Backward compatibility in derived classes
+ # if pretrained is not None:
+ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
+ # key, please consider using init_cfg')
+ # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+ @property
+ def is_init(self):
+ return self._is_init
+
+ def init_weights(self):
+ """Initialize the weights."""
+
+ is_top_level_module = False
+ # check if it is top-level module
+ if not hasattr(self, '_params_init_info'):
+ # The `_params_init_info` is used to record the initialization
+ # information of the parameters
+ # the key should be the obj:`nn.Parameter` of model and the value
+ # should be a dict containing
+ # - init_info (str): The string that describes the initialization.
+ # - tmp_mean_value (FloatTensor): The mean of the parameter,
+ # which indicates whether the parameter has been modified.
+ # this attribute would be deleted after all parameters
+ # is initialized.
+ self._params_init_info = defaultdict(dict)
+ is_top_level_module = True
+
+ # Initialize the `_params_init_info`,
+ # When detecting the `tmp_mean_value` of
+ # the corresponding parameter is changed, update related
+ # initialization information
+ for name, param in self.named_parameters():
+ self._params_init_info[param][
+ 'init_info'] = f'The value is the same before and ' \
+ f'after calling `init_weights` ' \
+ f'of {self.__class__.__name__} '
+ self._params_init_info[param][
+ 'tmp_mean_value'] = param.data.mean()
+
+ # pass `params_init_info` to all submodules
+ # All submodules share the same `params_init_info`,
+ # so it will be updated when parameters are
+ # modified at any level of the model.
+ for sub_module in self.modules():
+ sub_module._params_init_info = self._params_init_info
+
+ # Get the initialized logger, if not exist,
+ # create a logger named `mmcv`
+ logger_names = list(logger_initialized.keys())
+ logger_name = logger_names[0] if logger_names else 'mmcv'
+
+ from ..cnn import initialize
+ from ..cnn.utils.weight_init import update_init_info
+ module_name = self.__class__.__name__
+ if not self._is_init:
+ if self.init_cfg:
+ print_log(
+ f'initialize {module_name} with init_cfg {self.init_cfg}',
+ logger=logger_name)
+ initialize(self, self.init_cfg)
+ if isinstance(self.init_cfg, dict):
+ # prevent the parameters of
+ # the pre-trained model
+ # from being overwritten by
+ # the `init_weights`
+ if self.init_cfg['type'] == 'Pretrained':
+ return
+
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights()
+ # users may overload the `init_weights`
+ update_init_info(
+ m,
+ init_info=f'Initialized by '
+ f'user-defined `init_weights`'
+ f' in {m.__class__.__name__} ')
+
+ self._is_init = True
+ else:
+ warnings.warn(f'init_weights of {self.__class__.__name__} has '
+ f'been called more than once.')
+
+ if is_top_level_module:
+ self._dump_init_info(logger_name)
+
+ for sub_module in self.modules():
+ del sub_module._params_init_info
+
+ @master_only
+ def _dump_init_info(self, logger_name):
+ """Dump the initialization information to a file named
+ `initialization.log.json` in workdir.
+
+ Args:
+ logger_name (str): The name of logger.
+ """
+
+ logger = get_logger(logger_name)
+
+ with_file_handler = False
+ # dump the information to the logger file if there is a `FileHandler`
+ for handler in logger.handlers:
+ if isinstance(handler, FileHandler):
+ handler.stream.write(
+ 'Name of parameter - Initialization information\n')
+ for name, param in self.named_parameters():
+ handler.stream.write(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n")
+ handler.stream.flush()
+ with_file_handler = True
+ if not with_file_handler:
+ for name, param in self.named_parameters():
+ print_log(
+ f'\n{name} - {param.shape}: '
+ f"\n{self._params_init_info[param]['init_info']} \n ",
+ logger=logger_name)
+
+ def __repr__(self):
+ s = super().__repr__()
+ if self.init_cfg:
+ s += f'\ninit_cfg={self.init_cfg}'
+ return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+ """Sequential module in openmmlab.
+
+ Args:
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, *args, init_cfg=None):
+ BaseModule.__init__(self, init_cfg)
+ nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+ """ModuleList in openmmlab.
+
+ Args:
+ modules (iterable, optional): an iterable of modules to add.
+ init_cfg (dict, optional): Initialization config dict.
+ """
+
+ def __init__(self, modules=None, init_cfg=None):
+ BaseModule.__init__(self, init_cfg)
+ nn.ModuleList.__init__(self, modules)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..a75a7d5db9f281fda10008636b24e2b98d9336a0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.mmpkg.mmcv as mmcv
+from ..parallel import is_module_wrapper
+from .checkpoint import load_checkpoint
+from .dist_utils import get_dist_info
+from .hooks import HOOKS, Hook
+from .log_buffer import LogBuffer
+from .priority import Priority, get_priority
+from .utils import get_time_str
+
+
+class BaseRunner(metaclass=ABCMeta):
+ """The base class of Runner, a training helper for PyTorch.
+
+ All subclasses should implement the following APIs:
+
+ - ``run()``
+ - ``train()``
+ - ``val()``
+ - ``save_checkpoint()``
+
+ Args:
+ model (:obj:`torch.nn.Module`): The model to be run.
+ batch_processor (callable): A callable method that process a data
+ batch. The interface of this method should be
+ `batch_processor(model, data, train_mode) -> dict`
+ optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+ optimizer (in most cases) or a dict of optimizers (in models that
+ requires more than one optimizer, e.g., GAN).
+ work_dir (str, optional): The working directory to save checkpoints
+ and logs. Defaults to None.
+ logger (:obj:`logging.Logger`): Logger used during training.
+ Defaults to None. (The default value is just for backward
+ compatibility)
+ meta (dict | None): A dict records some import information such as
+ environment info and seed, which will be logged in logger hook.
+ Defaults to None.
+ max_epochs (int, optional): Total training epochs.
+ max_iters (int, optional): Total training iterations.
+ """
+
+ def __init__(self,
+ model,
+ batch_processor=None,
+ optimizer=None,
+ work_dir=None,
+ logger=None,
+ meta=None,
+ max_iters=None,
+ max_epochs=None):
+ if batch_processor is not None:
+ if not callable(batch_processor):
+ raise TypeError('batch_processor must be callable, '
+ f'but got {type(batch_processor)}')
+ warnings.warn('batch_processor is deprecated, please implement '
+ 'train_step() and val_step() in the model instead.')
+ # raise an error is `batch_processor` is not None and
+ # `model.train_step()` exists.
+ if is_module_wrapper(model):
+ _model = model.module
+ else:
+ _model = model
+ if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+ raise RuntimeError(
+ 'batch_processor and model.train_step()/model.val_step() '
+ 'cannot be both available.')
+ else:
+ assert hasattr(model, 'train_step')
+
+ # check the type of `optimizer`
+ if isinstance(optimizer, dict):
+ for name, optim in optimizer.items():
+ if not isinstance(optim, Optimizer):
+ raise TypeError(
+ f'optimizer must be a dict of torch.optim.Optimizers, '
+ f'but optimizer["{name}"] is a {type(optim)}')
+ elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+ raise TypeError(
+ f'optimizer must be a torch.optim.Optimizer object '
+ f'or dict or None, but got {type(optimizer)}')
+
+ # check the type of `logger`
+ if not isinstance(logger, logging.Logger):
+ raise TypeError(f'logger must be a logging.Logger object, '
+ f'but got {type(logger)}')
+
+ # check the type of `meta`
+ if meta is not None and not isinstance(meta, dict):
+ raise TypeError(
+ f'meta must be a dict or None, but got {type(meta)}')
+
+ self.model = model
+ self.batch_processor = batch_processor
+ self.optimizer = optimizer
+ self.logger = logger
+ self.meta = meta
+ # create work_dir
+ if mmcv.is_str(work_dir):
+ self.work_dir = osp.abspath(work_dir)
+ mmcv.mkdir_or_exist(self.work_dir)
+ elif work_dir is None:
+ self.work_dir = None
+ else:
+ raise TypeError('"work_dir" must be a str or None')
+
+ # get model name from the model class
+ if hasattr(self.model, 'module'):
+ self._model_name = self.model.module.__class__.__name__
+ else:
+ self._model_name = self.model.__class__.__name__
+
+ self._rank, self._world_size = get_dist_info()
+ self.timestamp = get_time_str()
+ self.mode = None
+ self._hooks = []
+ self._epoch = 0
+ self._iter = 0
+ self._inner_iter = 0
+
+ if max_epochs is not None and max_iters is not None:
+ raise ValueError(
+ 'Only one of `max_epochs` or `max_iters` can be set.')
+
+ self._max_epochs = max_epochs
+ self._max_iters = max_iters
+ # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+ self.log_buffer = LogBuffer()
+
+ @property
+ def model_name(self):
+ """str: Name of the model, usually the module class name."""
+ return self._model_name
+
+ @property
+ def rank(self):
+ """int: Rank of current process. (distributed training)"""
+ return self._rank
+
+ @property
+ def world_size(self):
+ """int: Number of processes participating in the job.
+ (distributed training)"""
+ return self._world_size
+
+ @property
+ def hooks(self):
+ """list[:obj:`Hook`]: A list of registered hooks."""
+ return self._hooks
+
+ @property
+ def epoch(self):
+ """int: Current epoch."""
+ return self._epoch
+
+ @property
+ def iter(self):
+ """int: Current iteration."""
+ return self._iter
+
+ @property
+ def inner_iter(self):
+ """int: Iteration in an epoch."""
+ return self._inner_iter
+
+ @property
+ def max_epochs(self):
+ """int: Maximum training epochs."""
+ return self._max_epochs
+
+ @property
+ def max_iters(self):
+ """int: Maximum training iterations."""
+ return self._max_iters
+
+ @abstractmethod
+ def train(self):
+ pass
+
+ @abstractmethod
+ def val(self):
+ pass
+
+ @abstractmethod
+ def run(self, data_loaders, workflow, **kwargs):
+ pass
+
+ @abstractmethod
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl,
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ pass
+
+ def current_lr(self):
+ """Get current learning rates.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current learning rates of all
+ param groups. If the runner has a dict of optimizers, this
+ method will return a dict.
+ """
+ if isinstance(self.optimizer, torch.optim.Optimizer):
+ lr = [group['lr'] for group in self.optimizer.param_groups]
+ elif isinstance(self.optimizer, dict):
+ lr = dict()
+ for name, optim in self.optimizer.items():
+ lr[name] = [group['lr'] for group in optim.param_groups]
+ else:
+ raise RuntimeError(
+ 'lr is not applicable because optimizer does not exist.')
+ return lr
+
+ def current_momentum(self):
+ """Get current momentums.
+
+ Returns:
+ list[float] | dict[str, list[float]]: Current momentums of all
+ param groups. If the runner has a dict of optimizers, this
+ method will return a dict.
+ """
+
+ def _get_momentum(optimizer):
+ momentums = []
+ for group in optimizer.param_groups:
+ if 'momentum' in group.keys():
+ momentums.append(group['momentum'])
+ elif 'betas' in group.keys():
+ momentums.append(group['betas'][0])
+ else:
+ momentums.append(0)
+ return momentums
+
+ if self.optimizer is None:
+ raise RuntimeError(
+ 'momentum is not applicable because optimizer does not exist.')
+ elif isinstance(self.optimizer, torch.optim.Optimizer):
+ momentums = _get_momentum(self.optimizer)
+ elif isinstance(self.optimizer, dict):
+ momentums = dict()
+ for name, optim in self.optimizer.items():
+ momentums[name] = _get_momentum(optim)
+ return momentums
+
+ def register_hook(self, hook, priority='NORMAL'):
+ """Register a hook into the hook list.
+
+ The hook will be inserted into a priority queue, with the specified
+ priority (See :class:`Priority` for details of priorities).
+ For hooks with the same priority, they will be triggered in the same
+ order as they are registered.
+
+ Args:
+ hook (:obj:`Hook`): The hook to be registered.
+ priority (int or str or :obj:`Priority`): Hook priority.
+ Lower value means higher priority.
+ """
+ assert isinstance(hook, Hook)
+ if hasattr(hook, 'priority'):
+ raise ValueError('"priority" is a reserved attribute for hooks')
+ priority = get_priority(priority)
+ hook.priority = priority
+ # insert the hook to a sorted list
+ inserted = False
+ for i in range(len(self._hooks) - 1, -1, -1):
+ if priority >= self._hooks[i].priority:
+ self._hooks.insert(i + 1, hook)
+ inserted = True
+ break
+ if not inserted:
+ self._hooks.insert(0, hook)
+
+ def register_hook_from_cfg(self, hook_cfg):
+ """Register a hook from its cfg.
+
+ Args:
+ hook_cfg (dict): Hook config. It should have at least keys 'type'
+ and 'priority' indicating its type and priority.
+
+ Notes:
+ The specific hook class to register should not use 'type' and
+ 'priority' arguments during initialization.
+ """
+ hook_cfg = hook_cfg.copy()
+ priority = hook_cfg.pop('priority', 'NORMAL')
+ hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
+ self.register_hook(hook, priority=priority)
+
+ def call_hook(self, fn_name):
+ """Call all hooks.
+
+ Args:
+ fn_name (str): The function name in each hook to be called, such as
+ "before_train_epoch".
+ """
+ for hook in self._hooks:
+ getattr(hook, fn_name)(self)
+
+ def get_hook_info(self):
+ # Get hooks info in each stage
+ stage_hook_map = {stage: [] for stage in Hook.stages}
+ for hook in self.hooks:
+ try:
+ priority = Priority(hook.priority).name
+ except ValueError:
+ priority = hook.priority
+ classname = hook.__class__.__name__
+ hook_info = f'({priority:<12}) {classname:<35}'
+ for trigger_stage in hook.get_triggered_stages():
+ stage_hook_map[trigger_stage].append(hook_info)
+
+ stage_hook_infos = []
+ for stage in Hook.stages:
+ hook_infos = stage_hook_map[stage]
+ if len(hook_infos) > 0:
+ info = f'{stage}:\n'
+ info += '\n'.join(hook_infos)
+ info += '\n -------------------- '
+ stage_hook_infos.append(info)
+ return '\n'.join(stage_hook_infos)
+
+ def load_checkpoint(self,
+ filename,
+ map_location='cpu',
+ strict=False,
+ revise_keys=[(r'^module.', '')]):
+ return load_checkpoint(
+ self.model,
+ filename,
+ map_location,
+ strict,
+ self.logger,
+ revise_keys=revise_keys)
+
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ if map_location == 'default':
+ if torch.cuda.is_available():
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(checkpoint)
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ if self.meta is None:
+ self.meta = {}
+ self.meta.setdefault('hook_msgs', {})
+ # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
+ self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
+
+ # Re-calculate the number of iterations when resuming
+ # models with different number of GPUs
+ if 'config' in checkpoint['meta']:
+ config = mmcv.Config.fromstring(
+ checkpoint['meta']['config'], file_format='.py')
+ previous_gpu_ids = config.get('gpu_ids', None)
+ if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
+ previous_gpu_ids) != self.world_size:
+ self._iter = int(self._iter * len(previous_gpu_ids) /
+ self.world_size)
+ self.logger.info('the iteration number is changed due to '
+ 'change of GPU number')
+
+ # resume meta information meta
+ self.meta = checkpoint['meta']
+
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+
+ def register_lr_hook(self, lr_config):
+ if lr_config is None:
+ return
+ elif isinstance(lr_config, dict):
+ assert 'policy' in lr_config
+ policy_type = lr_config.pop('policy')
+ # If the type of policy is all in lower case, e.g., 'cyclic',
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+ # This is for the convenient usage of Lr updater.
+ # Since this is not applicable for `
+ # CosineAnnealingLrUpdater`,
+ # the string will not be changed if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'LrUpdaterHook'
+ lr_config['type'] = hook_type
+ hook = mmcv.build_from_cfg(lr_config, HOOKS)
+ else:
+ hook = lr_config
+ self.register_hook(hook, priority='VERY_HIGH')
+
+ def register_momentum_hook(self, momentum_config):
+ if momentum_config is None:
+ return
+ if isinstance(momentum_config, dict):
+ assert 'policy' in momentum_config
+ policy_type = momentum_config.pop('policy')
+ # If the type of policy is all in lower case, e.g., 'cyclic',
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+ # This is for the convenient usage of momentum updater.
+ # Since this is not applicable for
+ # `CosineAnnealingMomentumUpdater`,
+ # the string will not be changed if it contains capital letters.
+ if policy_type == policy_type.lower():
+ policy_type = policy_type.title()
+ hook_type = policy_type + 'MomentumUpdaterHook'
+ momentum_config['type'] = hook_type
+ hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+ else:
+ hook = momentum_config
+ self.register_hook(hook, priority='HIGH')
+
+ def register_optimizer_hook(self, optimizer_config):
+ if optimizer_config is None:
+ return
+ if isinstance(optimizer_config, dict):
+ optimizer_config.setdefault('type', 'OptimizerHook')
+ hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+ else:
+ hook = optimizer_config
+ self.register_hook(hook, priority='ABOVE_NORMAL')
+
+ def register_checkpoint_hook(self, checkpoint_config):
+ if checkpoint_config is None:
+ return
+ if isinstance(checkpoint_config, dict):
+ checkpoint_config.setdefault('type', 'CheckpointHook')
+ hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+ else:
+ hook = checkpoint_config
+ self.register_hook(hook, priority='NORMAL')
+
+ def register_logger_hooks(self, log_config):
+ if log_config is None:
+ return
+ log_interval = log_config['interval']
+ for info in log_config['hooks']:
+ logger_hook = mmcv.build_from_cfg(
+ info, HOOKS, default_args=dict(interval=log_interval))
+ self.register_hook(logger_hook, priority='VERY_LOW')
+
+ def register_timer_hook(self, timer_config):
+ if timer_config is None:
+ return
+ if isinstance(timer_config, dict):
+ timer_config_ = copy.deepcopy(timer_config)
+ hook = mmcv.build_from_cfg(timer_config_, HOOKS)
+ else:
+ hook = timer_config
+ self.register_hook(hook, priority='LOW')
+
+ def register_custom_hooks(self, custom_config):
+ if custom_config is None:
+ return
+
+ if not isinstance(custom_config, list):
+ custom_config = [custom_config]
+
+ for item in custom_config:
+ if isinstance(item, dict):
+ self.register_hook_from_cfg(item)
+ else:
+ self.register_hook(item, priority='NORMAL')
+
+ def register_profiler_hook(self, profiler_config):
+ if profiler_config is None:
+ return
+ if isinstance(profiler_config, dict):
+ profiler_config.setdefault('type', 'ProfilerHook')
+ hook = mmcv.build_from_cfg(profiler_config, HOOKS)
+ else:
+ hook = profiler_config
+ self.register_hook(hook)
+
+ def register_training_hooks(self,
+ lr_config,
+ optimizer_config=None,
+ checkpoint_config=None,
+ log_config=None,
+ momentum_config=None,
+ timer_config=dict(type='IterTimerHook'),
+ custom_hooks_config=None):
+ """Register default and custom hooks for training.
+
+ Default and custom hooks include:
+
+ +----------------------+-------------------------+
+ | Hooks | Priority |
+ +======================+=========================+
+ | LrUpdaterHook | VERY_HIGH (10) |
+ +----------------------+-------------------------+
+ | MomentumUpdaterHook | HIGH (30) |
+ +----------------------+-------------------------+
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
+ +----------------------+-------------------------+
+ | CheckpointSaverHook | NORMAL (50) |
+ +----------------------+-------------------------+
+ | IterTimerHook | LOW (70) |
+ +----------------------+-------------------------+
+ | LoggerHook(s) | VERY_LOW (90) |
+ +----------------------+-------------------------+
+ | CustomHook(s) | defaults to NORMAL (50) |
+ +----------------------+-------------------------+
+
+ If custom hooks have same priority with default hooks, custom hooks
+ will be triggered after default hooks.
+ """
+ self.register_lr_hook(lr_config)
+ self.register_momentum_hook(momentum_config)
+ self.register_optimizer_hook(optimizer_config)
+ self.register_checkpoint_hook(checkpoint_config)
+ self.register_timer_hook(timer_config)
+ self.register_logger_hooks(log_config)
+ self.register_custom_hooks(custom_hooks_config)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/builder.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/builder.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from ..utils import Registry
+
+RUNNERS = Registry('runner')
+RUNNER_BUILDERS = Registry('runner builder')
+
+
+def build_runner_constructor(cfg):
+ return RUNNER_BUILDERS.build(cfg)
+
+
+def build_runner(cfg, default_args=None):
+ runner_cfg = copy.deepcopy(cfg)
+ constructor_type = runner_cfg.pop('constructor',
+ 'DefaultRunnerConstructor')
+ runner_constructor = build_runner_constructor(
+ dict(
+ type=constructor_type,
+ runner_cfg=runner_cfg,
+ default_args=default_args))
+ runner = runner_constructor()
+ return runner
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d690be1dfe70b1b82eaac8fe4db7022b35d5426c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py
@@ -0,0 +1,707 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import re
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+
+import annotator.mmpkg.mmcv as mmcv
+from ..fileio import FileClient
+from ..fileio import load as load_file
+from ..parallel import is_module_wrapper
+from ..utils import mkdir_or_exist
+from .dist_utils import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+    """Return the mmcv home directory after passing it to ``mkdir_or_exist``.
+
+    ``$MMCV_HOME`` is used when set; otherwise the path is
+    ``$XDG_CACHE_HOME/mmcv`` with ``~/.cache`` standing in for an unset
+    ``XDG_CACHE_HOME``. A leading ``~`` is expanded.
+    """
+    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
+    fallback_home = os.path.join(cache_root, 'mmcv')
+    mmcv_home = os.path.expanduser(os.getenv(ENV_MMCV_HOME, fallback_home))
+    # presumably creates the directory when it is missing
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+    """Copy ``state_dict`` into ``module`` and report key mismatches.
+
+    Adapted from :meth:`torch.nn.Module.load_state_dict`. ``strict``
+    defaults to ``False`` here, and a mismatch message is still emitted
+    (warning or print) when ``strict`` is ``False``. Missing keys that
+    contain ``num_batches_tracked`` are not reported.
+
+    Args:
+        module (Module): Module that receives the state_dict.
+        state_dict (OrderedDict): Weights.
+        strict (bool): Raise instead of only reporting when the keys of
+            ``state_dict`` and of the module do not match.
+            Default: ``False``.
+        logger (:obj:`logging.Logger`, optional): Logger used for the
+            mismatch message. If not specified, ``print`` is used.
+
+    Raises:
+        RuntimeError: If ``strict`` is true, a mismatch was collected and
+            the current rank is 0.
+    """
+    unexpected_keys = []
+    all_missing_keys = []
+    err_msg = []
+
+    # Work on a copy of the mapping and re-attach ``_metadata`` to it.
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    # use _load_from_state_dict to enable checkpoint version control
+    def _load_into(submodule, prefix=''):
+        # unwrap parallel wrappers at every level in case the model has a
+        # complicated structure, e.g., nn.Module(nn.Module(DDP))
+        if is_module_wrapper(submodule):
+            submodule = submodule.module
+        if metadata is None:
+            local_metadata = {}
+        else:
+            local_metadata = metadata.get(prefix[:-1], {})
+        submodule._load_from_state_dict(state_dict, prefix, local_metadata,
+                                        True, all_missing_keys,
+                                        unexpected_keys, err_msg)
+        for child_name, child in submodule._modules.items():
+            if child is not None:
+                _load_into(child, prefix + child_name + '.')
+
+    _load_into(module)
+    _load_into = None  # break the closure's reference cycle on itself
+
+    # ignore "num_batches_tracked" of BN layers
+    missing_keys = [
+        name for name in all_missing_keys
+        if 'num_batches_tracked' not in name
+    ]
+
+    if unexpected_keys:
+        joined = ", ".join(unexpected_keys)
+        err_msg.append(f'unexpected key in source state_dict: {joined}\n')
+    if missing_keys:
+        joined = ", ".join(missing_keys)
+        err_msg.append(f'missing keys in source state_dict: {joined}\n')
+
+    rank, _ = get_dist_info()
+    if rank != 0 or len(err_msg) == 0:
+        return
+
+    header = 'The model and loaded state dict do not match exactly\n'
+    err_msg = '\n'.join([header] + err_msg)
+    if strict:
+        raise RuntimeError(err_msg)
+    if logger is not None:
+        logger.warning(err_msg)
+    else:
+        print(err_msg)
+
+
+def get_torchvision_models():
+    """Merge the ``model_urls`` tables of the ``torchvision.models`` modules.
+
+    Every non-package entry found under ``torchvision.models.__path__`` is
+    imported; modules without a ``model_urls`` attribute are ignored.
+
+    Returns:
+        dict: Union of all ``model_urls`` mappings that were found.
+    """
+    collected = dict()
+    for _, mod_name, is_package in pkgutil.walk_packages(
+            torchvision.models.__path__):
+        if not is_package:
+            zoo = import_module(f'torchvision.models.{mod_name}')
+            if hasattr(zoo, 'model_urls'):
+                collected.update(getattr(zoo, 'model_urls'))
+    return collected
+
+
+def get_external_models():
+    """Load the OpenMMLab model table, overlaid with a user-level table.
+
+    The packaged ``model_zoo/open_mmlab.json`` is read first. If
+    ``open_mmlab.json`` exists in the mmcv home directory, its entries are
+    merged on top and win on key collisions.
+
+    Returns:
+        dict: The merged table.
+    """
+    mmcv_home = _get_mmcv_home()
+    urls = load_file(osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
+    assert isinstance(urls, dict)
+
+    user_json = osp.join(mmcv_home, 'open_mmlab.json')
+    if not osp.exists(user_json):
+        return urls
+
+    user_urls = load_file(user_json)
+    assert isinstance(user_urls, dict)
+    urls.update(user_urls)
+    return urls
+
+
+def get_mmcls_models():
+    """Return the content of the packaged ``model_zoo/mmcls.json`` file."""
+    json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+    return load_file(json_path)
+
+
+def get_deprecated_model_names():
+    """Return the packaged ``model_zoo/deprecated.json`` table.
+
+    Returns:
+        dict: The loaded table; ``load_from_openmmlab`` uses it to map a
+        deprecated model name to its replacement.
+    """
+    package_dir = mmcv.__path__[0]
+    table = load_file(osp.join(package_dir, 'model_zoo/deprecated.json'))
+    assert isinstance(table, dict)
+    return table
+
+
+def _process_mmcls_checkpoint(checkpoint):
+    """Keep only the ``backbone.`` weights of a checkpoint, prefix removed.
+
+    Args:
+        checkpoint (dict): Checkpoint with a ``state_dict`` entry.
+
+    Returns:
+        dict: A new dict whose only key, ``state_dict``, maps to an
+        OrderedDict of the entries whose key started with ``backbone.``,
+        with that prefix stripped. Other entries are dropped.
+    """
+    stripped = OrderedDict(
+        # 9 == len('backbone.')
+        (key[9:], value) for key, value in checkpoint['state_dict'].items()
+        if key.startswith('backbone.'))
+    return dict(state_dict=stripped)
+
+
+class CheckpointLoader:
+    """A general checkpoint loader to manage all schemes.
+
+    Loaders are callables stored per path prefix in ``_schemes``;
+    :meth:`load_checkpoint` uses the first registered prefix that the
+    filename starts with.
+    """
+
+    _schemes = {}
+
+    @classmethod
+    def _register_scheme(cls, prefixes, loader, force=False):
+        """Store ``loader`` under each prefix, then re-sort the registry.
+
+        Raises:
+            KeyError: If a prefix is already registered and ``force`` is
+                false.
+        """
+        if not isinstance(prefixes, str):
+            assert isinstance(prefixes, (list, tuple))
+        else:
+            prefixes = [prefixes]
+        for prefix in prefixes:
+            if prefix in cls._schemes and not force:
+                raise KeyError(
+                    f'{prefix} is already registered as a loader backend, '
+                    'add "force=True" if you want to override it')
+            cls._schemes[prefix] = loader
+        # Descending key order: a prefix that extends another one is tried
+        # before the shorter one, and '' comes last.
+        ordered = sorted(
+            cls._schemes.items(), key=lambda item: item[0], reverse=True)
+        cls._schemes = OrderedDict(ordered)
+
+    @classmethod
+    def register_scheme(cls, prefixes, loader=None, force=False):
+        """Register a loader, either by direct call or as a decorator.
+
+        Args:
+            prefixes (str or list[str] or tuple[str]):
+                The prefix(es) the loader is registered for.
+            loader (function, optional): The loader function to register.
+                Leave it as None to use this method as a decorator.
+                Defaults to None.
+            force (bool, optional): Whether to override a loader that is
+                already registered for a prefix. Defaults to False.
+
+        Returns:
+            None when ``loader`` is given, otherwise a decorator that
+            registers the decorated object and returns it unchanged.
+        """
+        if loader is None:
+
+            def _register(loader_cls):
+                cls._register_scheme(prefixes, loader_cls, force=force)
+                return loader_cls
+
+            return _register
+
+        cls._register_scheme(prefixes, loader, force=force)
+
+    @classmethod
+    def _get_checkpoint_loader(cls, path):
+        """Return the loader of the first registered prefix matching ``path``.
+
+        The local loader in this file is registered under ``''``, which
+        every path starts with, so it acts as the fallback.
+
+        Args:
+            path (str): checkpoint path
+
+        Returns:
+            loader (function): checkpoint loader, or None if no registered
+            prefix matches.
+        """
+        for prefix, loader in cls._schemes.items():
+            if path.startswith(prefix):
+                return loader
+
+    @classmethod
+    def load_checkpoint(cls, filename, map_location=None, logger=None):
+        """Load a checkpoint with the loader selected by ``filename``.
+
+        Args:
+            filename (str): checkpoint file name with given prefix
+            map_location (str, optional): Same as :func:`torch.load`.
+                Default: None
+            logger (:mod:`logging.Logger`, optional): The logger for message.
+                Default: None
+
+        Returns:
+            dict or OrderedDict: The loaded checkpoint.
+        """
+        loader = cls._get_checkpoint_loader(filename)
+        # drop the first 10 characters ('load_from_' for the loaders
+        # defined in this file) to get a short source name for the log
+        source = loader.__name__[10:]
+        mmcv.print_log(
+            f'load checkpoint from {source} path: {filename}', logger)
+        return loader(filename, map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes='')
+def load_from_local(filename, map_location):
+    """Load a checkpoint from a local file.
+
+    Registered under the empty prefix, which every path starts with.
+
+    Args:
+        filename (str): local checkpoint file path
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        IOError: If ``filename`` is not an existing regular file.
+    """
+    if osp.isfile(filename):
+        return torch.load(filename, map_location=map_location)
+    raise IOError(f'{filename} is not a checkpoint file')
+
+
+@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
+def load_from_http(filename, map_location=None, model_dir=None):
+    """Load a checkpoint from an HTTP or HTTPS URL.
+
+    In a distributed run (world size > 1) local rank 0 calls
+    ``model_zoo.load_url`` first; the other ranks wait at a barrier and
+    call it afterwards. Without a process group the checkpoint is always
+    loaded by the calling process.
+
+    Args:
+        filename (str): ``http://`` or ``https://`` URL of the checkpoint.
+        map_location (str, optional): Same as :func:`torch.load`.
+        model_dir (string, optional): directory in which to save the object,
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    rank, world_size = get_dist_info()
+    # the launcher-provided LOCAL_RANK takes precedence over the global rank
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+
+    def _fetch():
+        return model_zoo.load_url(
+            filename, model_dir=model_dir, map_location=map_location)
+
+    if world_size <= 1:
+        # Nobody to synchronise with. Load unconditionally: a stray non-zero
+        # LOCAL_RANK in the environment used to leave ``checkpoint``
+        # unassigned here and end in an UnboundLocalError.
+        return _fetch()
+
+    if rank == 0:
+        checkpoint = _fetch()
+    torch.distributed.barrier()
+    if rank != 0:
+        checkpoint = _fetch()
+    return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='pavi://')
+def load_from_pavi(filename, map_location=None):
+    """Load a checkpoint addressed by a ``pavi://`` path.
+
+    The file is downloaded into a fresh temporary directory on every call,
+    so in a distributed setting each rank downloads its own copy.
+
+    Args:
+        filename (str): checkpoint file path with pavi prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        ImportError: If the ``pavi`` package cannot be imported.
+    """
+    assert filename.startswith('pavi://'), (
+        f'Expected filename startswith `pavi://`, but get {filename}')
+    remote_path = filename[7:]  # 7 == len('pavi://')
+
+    try:
+        from pavi import modelcloud
+    except ImportError:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.')
+
+    remote_model = modelcloud.get(remote_path)
+    with TemporaryDirectory() as tmp_dir:
+        local_copy = osp.join(tmp_dir, remote_model.name)
+        remote_model.download(local_copy)
+        # load before leaving the block: the directory is removed on exit
+        return torch.load(local_copy, map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes='s3://')
+def load_from_ceph(filename, map_location=None, backend='petrel'):
+    """Load a checkpoint addressed by an ``s3://`` path.
+
+    The bytes are fetched through a ``FileClient`` and deserialised from
+    memory on every call, so each rank fetches its own copy.
+
+    Args:
+        filename (str): checkpoint file path with s3 prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+        backend (str, optional): The storage backend type. Options are 'ceph',
+            'petrel'. Default: 'petrel'.
+
+    .. warning::
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        ValueError: If ``backend`` is neither ``'ceph'`` nor ``'petrel'``.
+    """
+    allowed_backends = ['ceph', 'petrel']
+    if backend not in allowed_backends:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+
+    if backend == 'ceph':
+        warnings.warn(
+            'CephBackend will be deprecated, please use PetrelBackend instead')
+
+    # Both backends serve the 's3://' prefix. If the requested one cannot
+    # be instantiated because of an ImportError, the other one is used.
+    try:
+        file_client = FileClient(backend=backend)
+    except ImportError:
+        other_backend = [b for b in allowed_backends if b != backend][0]
+        file_client = FileClient(backend=other_backend)
+
+    with io.BytesIO(file_client.get(filename)) as buffer:
+        return torch.load(buffer, map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
+def load_from_torchvision(filename, map_location=None):
+    """Load a torchvision checkpoint named by a ``torchvision://`` or
+    (deprecated) ``modelzoo://`` path.
+
+    The part after the prefix is looked up in the table returned by
+    ``get_torchvision_models`` and the resulting URL is loaded with
+    ``load_from_http``.
+
+    Args:
+        filename (str): checkpoint file path with modelzoo or
+            torchvision prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    model_urls = get_torchvision_models()
+    if not filename.startswith('modelzoo://'):
+        model_name = filename[14:]  # 14 == len('torchvision://')
+    else:
+        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+                      'use "torchvision://" instead')
+        model_name = filename[11:]  # 11 == len('modelzoo://')
+    url = model_urls[model_name]
+    return load_from_http(url, map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
+def load_from_openmmlab(filename, map_location=None):
+    """Load a checkpoint named by an ``open-mmlab://`` or ``openmmlab://``
+    path.
+
+    The name is looked up in the table from ``get_external_models``;
+    deprecated names are replaced (with a warning) first. A table entry
+    that is an HTTP(S) URL is downloaded, anything else is treated as a
+    path relative to the mmcv home directory.
+
+    Args:
+        filename (str): checkpoint file path with open-mmlab or
+            openmmlab prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+
+    Raises:
+        IOError: If the entry is a local path that is not an existing file.
+    """
+    model_urls = get_external_models()
+    if filename.startswith('open-mmlab://'):
+        prefix_str, model_name = 'open-mmlab://', filename[13:]
+    else:
+        prefix_str, model_name = 'openmmlab://', filename[12:]
+
+    deprecated_urls = get_deprecated_model_names()
+    if model_name in deprecated_urls:
+        replacement = deprecated_urls[model_name]
+        warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
+                      f'of {prefix_str}{replacement}')
+        model_name = replacement
+
+    model_url = model_urls[model_name]
+    # check if is url
+    if model_url.startswith(('http://', 'https://')):
+        return load_from_http(model_url, map_location=map_location)
+
+    local_path = osp.join(_get_mmcv_home(), model_url)
+    if not osp.isfile(local_path):
+        raise IOError(f'{local_path} is not a checkpoint file')
+    return torch.load(local_path, map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes='mmcls://')
+def load_from_mmcls(filename, map_location=None):
+    """Load a checkpoint named by an ``mmcls://`` path.
+
+    The downloaded checkpoint is reduced by ``_process_mmcls_checkpoint``
+    before it is returned.
+
+    Args:
+        filename (str): checkpoint file path with mmcls prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    url_table = get_mmcls_models()
+    url = url_table[filename[8:]]  # 8 == len('mmcls://')
+    raw_checkpoint = load_from_http(url, map_location=map_location)
+    return _process_mmcls_checkpoint(raw_checkpoint)
+
+
+def _load_checkpoint(filename, map_location=None, logger=None):
+    """Load a checkpoint via ``CheckpointLoader.load_checkpoint``.
+
+    Args:
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str, optional): Same as :func:`torch.load`.
+            Default: None.
+        logger (:mod:`logging.Logger`, optional): Logger handed to the
+            loader for its message. Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint. It can be either an
+            OrderedDict storing model weights or a dict containing other
+            information, which depends on the checkpoint.
+    """
+    loaded = CheckpointLoader.load_checkpoint(filename, map_location, logger)
+    return loaded
+
+
+def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+    """Load only the weights of one sub-module from a checkpoint.
+
+    Args:
+        prefix (str): The prefix of sub-module. A trailing ``.`` is
+            appended when missing.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict: The entries whose key starts with ``prefix``, with the prefix
+        removed from the keys.
+
+    Raises:
+        AssertionError: If no key starts with ``prefix``.
+    """
+    checkpoint = _load_checkpoint(filename, map_location=map_location)
+    weights = (
+        checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)
+
+    if not prefix.endswith('.'):
+        prefix += '.'
+    cut = len(prefix)
+
+    selected = {}
+    for key, value in weights.items():
+        if key.startswith(prefix):
+            selected[key[cut:]] = value
+
+    assert selected, f'{prefix} is not in the pretrained model'
+    return selected
+
+
+def load_checkpoint(model,
+                    filename,
+                    map_location=None,
+                    strict=False,
+                    logger=None,
+                    revise_keys=[(r'^module\.', '')]):
+    """Load checkpoint from a file or URI.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to allow different params for the model and
+            checkpoint.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+        revise_keys (list): A list of customized keywords to modify the
+            state_dict in checkpoint. Each item is a (pattern, replacement)
+            pair of the regular expression operations. Default: strip
+            the prefix 'module.' by [(r'^module\\.', '')]. The default list
+            is shared between calls and is never modified here.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint, as returned by the
+        loader (its keys are not revised).
+
+    Raises:
+        RuntimeError: If the loaded checkpoint is not a dict.
+    """
+    checkpoint = _load_checkpoint(filename, map_location, logger)
+    # OrderedDict is a subclass of dict
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # get state_dict from checkpoint
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+
+    metadata = getattr(state_dict, '_metadata', OrderedDict())
+    # Rebuild as an OrderedDict up front: a plain ``dict`` cannot carry the
+    # ``_metadata`` attribute set below, which raised AttributeError when
+    # ``revise_keys`` was empty and the checkpoint held a plain dict.
+    state_dict = OrderedDict(state_dict)
+    # strip prefix of state_dict
+    for pattern, replacement in revise_keys:
+        state_dict = OrderedDict(
+            {re.sub(pattern, replacement, k): v
+             for k, v in state_dict.items()})
+    # Keep metadata in state_dict
+    state_dict._metadata = metadata
+
+    # load state_dict
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
+
+
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model weights, e.g. on GPU.
+
+    Returns:
+        OrderedDict: Model weights on CPU, in the same key order. Its
+        ``_metadata`` is the input's ``_metadata`` object, or an empty
+        OrderedDict if the input has none.
+    """
+    cpu_weights = OrderedDict(
+        (name, tensor.cpu()) for name, tensor in state_dict.items())
+    # Keep metadata in state_dict
+    cpu_weights._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    return cpu_weights
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Write a module's own parameters and buffers into ``destination``.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+    Child modules are not visited, and ``None`` entries are skipped.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+        keep_vars (bool): Store the tensors themselves when true, otherwise
+            their ``detach()``-ed versions.
+    """
+    # Parameters first, then buffers. Buffers are not checked against
+    # _non_persistent_buffers_set, to allow nn.BatchNorm2d.
+    for registry in (module._parameters, module._buffers):
+        for name, tensor in registry.items():
+            if tensor is None:
+                continue
+            destination[prefix + name] = (
+                tensor if keep_vars else tensor.detach())
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Returns a dictionary containing a whole state of the module.
+
+    Parameters and buffers (e.g. running averages) are included, keyed by
+    their prefixed names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    unwrap parallel wrappers at every level in case the model has a
+    complicated structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Dict that receives the state. A new
+            OrderedDict with an empty ``_metadata`` is created when None.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    if is_module_wrapper(module):
+        module = module.module
+
+    # below is the same as torch.nn.Module.state_dict()
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    local_metadata = dict(version=module._version)
+    destination._metadata[prefix[:-1]] = local_metadata
+
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for child_name, child in module._modules.items():
+        if child is None:
+            continue
+        get_state_dict(
+            child, destination, prefix + child_name + '.', keep_vars=keep_vars)
+
+    # a hook may hand back a replacement for the destination dict
+    for hook in module._state_dict_hooks.values():
+        replaced = hook(module, destination, prefix, local_metadata)
+        if replaced is not None:
+            destination = replaced
+    return destination
+
+
+def save_checkpoint(model,
+                    filename,
+                    optimizer=None,
+                    meta=None,
+                    file_client_args=None):
+    """Save checkpoint to file.
+
+    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+    ``optimizer``. By default ``meta`` will contain version and time info.
+
+    Args:
+        model (Module): Module whose params are to be saved.
+        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads
+            the checkpoint through ``pavi.modelcloud``; anything else is
+            written through a ``FileClient``.
+        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. A dict
+            of optimizers is saved as a dict of their state dicts.
+        meta (dict, optional): Metadata to be saved in checkpoint. The dict
+            is updated in place with ``mmcv_version``, ``time`` and, when
+            the model defines them, ``CLASSES``.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+
+    Raises:
+        TypeError: If ``meta`` is neither a dict nor None.
+        ValueError: If ``file_client_args`` is given together with a
+            ``pavi://`` filename.
+        ImportError: If a ``pavi://`` filename is used without ``pavi``.
+    """
+    if meta is None:
+        meta = {}
+    elif not isinstance(meta, dict):
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+    if is_module_wrapper(model):
+        model = model.module
+
+    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+        # save class name to the meta
+        meta.update(CLASSES=model.CLASSES)
+
+    checkpoint = {
+        'meta': meta,
+        'state_dict': weights_to_cpu(get_state_dict(model))
+    }
+    # save optimizer state dict in the checkpoint
+    if isinstance(optimizer, Optimizer):
+        checkpoint['optimizer'] = optimizer.state_dict()
+    elif isinstance(optimizer, dict):
+        checkpoint['optimizer'] = {
+            name: optim.state_dict()
+            for name, optim in optimizer.items()
+        }
+
+    if filename.startswith('pavi://'):
+        if file_client_args is not None:
+            # the space after "with" was missing between the two literals
+            raise ValueError(
+                'file_client_args should be "None" if filename starts with '
+                f'"pavi://", but got {file_client_args}')
+        try:
+            from pavi import modelcloud
+            from pavi import exception
+        except ImportError:
+            raise ImportError(
+                'Please install pavi to save checkpoint to modelcloud.')
+        model_path = filename[7:]  # 7 == len('pavi://')
+        root = modelcloud.Folder()
+        model_dir, model_name = osp.split(model_path)
+        try:
+            model = modelcloud.get(model_dir)
+        except exception.NodeNotFoundError:
+            model = root.create_training_model(model_dir)
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint_file = osp.join(tmp_dir, model_name)
+            with open(checkpoint_file, 'wb') as f:
+                torch.save(checkpoint, f)
+                f.flush()
+            # upload while the temporary directory still exists
+            model.create_file(checkpoint_file, name=model_name)
+    else:
+        file_client = FileClient.infer_client(file_client_args, filename)
+        with io.BytesIO() as f:
+            torch.save(checkpoint, f)
+            file_client.put(f.getvalue(), filename)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdd7803289d6d70240977fa243d7f4432ccde8f8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py
@@ -0,0 +1,44 @@
+from .builder import RUNNER_BUILDERS, RUNNERS
+
+
+@RUNNER_BUILDERS.register_module()
+class DefaultRunnerConstructor:
+    """Default constructor for runners.
+
+    Customize an existing `Runner` like `EpochBasedRunner` through a
+    `RunnerConstructor`. For example, we can inject some new properties and
+    functions into the `Runner`.
+
+    Example:
+        >>> from annotator.mmpkg.mmcv.runner import RUNNER_BUILDERS, build_runner
+        >>> # Define a new RunnerConstructor
+        >>> @RUNNER_BUILDERS.register_module()
+        >>> class MyRunnerConstructor:
+        ...     def __init__(self, runner_cfg, default_args=None):
+        ...         if not isinstance(runner_cfg, dict):
+        ...             raise TypeError('runner_cfg should be a dict, '
+        ...                             f'but got {type(runner_cfg)}')
+        ...         self.runner_cfg = runner_cfg
+        ...         self.default_args = default_args
+        ...
+        ...     def __call__(self):
+        ...         runner = RUNNERS.build(self.runner_cfg,
+        ...                                default_args=self.default_args)
+        ...         # Add new properties for existing runner
+        ...         runner.my_name = 'my_runner'
+        ...         runner.my_function = lambda self: print(self.my_name)
+        ...         ...
+        >>> # build your runner
+        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
+        ...                   constructor='MyRunnerConstructor')
+        >>> runner = build_runner(runner_cfg)
+    """
+
+    def __init__(self, runner_cfg, default_args=None):
+        """Store the runner config and the default arguments.
+
+        Args:
+            runner_cfg (dict): Config later passed to ``RUNNERS.build``.
+            default_args: Forwarded to ``RUNNERS.build`` as ``default_args``.
+
+        Raises:
+            TypeError: If ``runner_cfg`` is not a dict.
+        """
+        if not isinstance(runner_cfg, dict):
+            # One message string: the former two-argument form made the
+            # exception render as a tuple instead of a sentence.
+            raise TypeError('runner_cfg should be a dict, '
+                            f'but got {type(runner_cfg)}')
+        self.runner_cfg = runner_cfg
+        self.default_args = default_args
+
+    def __call__(self):
+        """Build and return the runner described by ``runner_cfg``."""
+        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import os
+import subprocess
+from collections import OrderedDict
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    """Initialise distributed training for the given launcher.
+
+    The multiprocessing start method is set to ``spawn`` when none has
+    been chosen yet, before the launcher-specific initialiser is called.
+
+    Args:
+        launcher (str): One of ``'pytorch'``, ``'mpi'`` or ``'slurm'``.
+        backend (str): Backend of torch.distributed. Default: ``'nccl'``.
+        **kwargs: Forwarded to the launcher-specific initialiser.
+
+    Raises:
+        ValueError: If ``launcher`` is not one of the supported values.
+    """
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+
+    if launcher == 'pytorch':
+        return _init_dist_pytorch(backend, **kwargs)
+    if launcher == 'mpi':
+        return _init_dist_mpi(backend, **kwargs)
+    if launcher == 'slurm':
+        return _init_dist_slurm(backend, **kwargs)
+    raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    """Select a CUDA device from ``$RANK`` and create the process group.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        **kwargs: Forwarded to ``dist.init_process_group``.
+    """
+    # TODO: use local_rank instead of rank % num_gpus
+    global_rank = int(os.environ['RANK'])
+    device_index = global_rank % torch.cuda.device_count()
+    torch.cuda.set_device(device_index)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    """Select a CUDA device from ``$OMPI_COMM_WORLD_RANK`` and create the
+    process group.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        **kwargs: Forwarded to ``dist.init_process_group``.
+    """
+    # TODO: use local_rank instead of rank % num_gpus
+    global_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+    device_index = global_rank % torch.cuda.device_count()
+    torch.cuda.set_device(device_index)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+    """Initialize slurm distributed training environment.
+
+    The master port is ``port`` when given, otherwise the existing
+    ``MASTER_PORT`` environment variable, otherwise ``29500``. An existing
+    ``MASTER_ADDR`` is kept; when unset it becomes the first host name that
+    ``scontrol`` reports for ``$SLURM_NODELIST``. ``WORLD_SIZE``,
+    ``LOCAL_RANK`` and ``RANK`` are always overwritten.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        port (int, optional): Master port. Defaults to None.
+    """
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    local_rank = proc_id % torch.cuda.device_count()
+    torch.cuda.set_device(local_rank)
+    first_host = subprocess.getoutput(
+        f'scontrol show hostname {node_list} | head -n1')
+
+    # specify master port
+    if port is not None:
+        os.environ['MASTER_PORT'] = str(port)
+    else:
+        # 29500 is torch.distributed default port
+        os.environ.setdefault('MASTER_PORT', '29500')
+    # use MASTER_ADDR in the environment variable if it already exists
+    os.environ.setdefault('MASTER_ADDR', first_host)
+
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['LOCAL_RANK'] = str(local_rank)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    """Return ``(rank, world_size)`` of the current process.
+
+    Returns:
+        tuple[int, int]: Values from ``torch.distributed`` when it is
+        available and initialised, otherwise ``(0, 1)``.
+    """
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_rank(), dist.get_world_size()
+    return 0, 1
+
+
+def master_only(func):
+    """Decorator that runs ``func`` only on rank 0.
+
+    On any other rank the wrapped call does nothing and returns None.
+    """
+
+    @functools.wraps(func)
+    def rank_zero_call(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank != 0:
+            return None
+        return func(*args, **kwargs)
+
+    return rank_zero_call
+
+
+def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce parameters.
+
+    Does nothing when the world size is 1. Otherwise each tensor is
+    all-reduced and divided by the world size, in place.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters or buffers of a
+            model.
+        coalesce (bool, optional): Whether allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    _, world_size = get_dist_info()
+    if world_size == 1:
+        return
+
+    tensors = [param.data for param in params]
+    if not coalesce:
+        for tensor in tensors:
+            dist.all_reduce(tensor.div_(world_size))
+        return
+    _allreduce_coalesced(tensors, world_size, bucket_size_mb)
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce gradients.
+
+    Only parameters with ``requires_grad`` set and a non-None ``grad`` take
+    part. Nothing is reduced when the world size is 1; otherwise each
+    gradient is all-reduced and divided by the world size, in place.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters of a model
+        coalesce (bool, optional): Whether allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    grads = []
+    for param in params:
+        if param.requires_grad and param.grad is not None:
+            grads.append(param.grad.data)
+
+    _, world_size = get_dist_info()
+    if world_size == 1:
+        return
+
+    if not coalesce:
+        for grad in grads:
+            dist.all_reduce(grad.div_(world_size))
+        return
+    _allreduce_coalesced(grads, world_size, bucket_size_mb)
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    """All-reduce tensors bucket by bucket and average them in place.
+
+    Args:
+        tensors (list[Tensor]): Tensors to synchronise; each one is
+            overwritten with its reduced value divided by ``world_size``.
+        world_size (int): Divisor applied after the all-reduce.
+        bucket_size_mb (int, optional): When positive, buckets are built by
+            ``_take_tensors`` with this size in MB; otherwise tensors are
+            grouped by their ``type()`` string. Defaults to -1.
+    """
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        by_type = OrderedDict()
+        for tensor in tensors:
+            by_type.setdefault(tensor.type(), []).append(tensor)
+        buckets = by_type.values()
+
+    for bucket in buckets:
+        # one collective call per bucket on a flattened copy
+        flat = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat)
+        flat.div_(world_size)
+        averaged = _unflatten_dense_tensors(flat, bucket)
+        for tensor, synced in zip(bucket, averaged):
+            tensor.copy_(synced)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4df071e1740baa4aea2951590ac929b3715daa2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+
+import annotator.mmpkg.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .utils import get_host_info
+
+
+@RUNNERS.register_module()
+class EpochBasedRunner(BaseRunner):
+    """Epoch-based Runner.
+
+    This runner trains models epoch by epoch.
+    """
+
+    def run_iter(self, data_batch, train_mode, **kwargs):
+        """Run one iteration; the result is kept in ``self.outputs``."""
+        if self.batch_processor is not None:
+            outputs = self.batch_processor(
+                self.model, data_batch, train_mode=train_mode, **kwargs)
+        else:
+            # Without a batch processor the model's own step method is used.
+            step = self.model.train_step if train_mode else self.model.val_step
+            outputs = step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('"batch_processor()" or "model.train_step()" '
+                            'and "model.val_step()" must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+
+    def train(self, data_loader, **kwargs):
+        """Train for one epoch over ``data_loader``, firing the train hooks."""
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._max_iters = self._max_epochs * len(self.data_loader)
+        self.call_hook('before_train_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for i, data_batch in enumerate(self.data_loader):
+            self._inner_iter = i
+            self.call_hook('before_train_iter')
+            self.run_iter(data_batch, train_mode=True, **kwargs)
+            self.call_hook('after_train_iter')
+            self._iter += 1
+        self.call_hook('after_train_epoch')
+        self._epoch += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        """Run one validation epoch; ``kwargs`` is accepted but not used."""
+        self.model.eval()
+        self.mode = 'val'
+        self.data_loader = data_loader
+        self.call_hook('before_val_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for i, data_batch in enumerate(self.data_loader):
+            self._inner_iter = i
+            self.call_hook('before_val_iter')
+            self.run_iter(data_batch, train_mode=False)
+            self.call_hook('after_val_iter')
+        self.call_hook('after_val_epoch')
+
+    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g, [('train', 2), ('val', 1)] means
+                running 2 epochs for training and 1 epoch for validation,
+                iteratively.
+            max_epochs (int, optional): Deprecated. When given it replaces
+                the configured maximum number of epochs. Defaults to None.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_epochs is not None:
+            warnings.warn(
+                'setting max_epochs in run is deprecated, '
+                'please set max_epochs in runner_config', DeprecationWarning)
+            self._max_epochs = max_epochs
+
+        assert self._max_epochs is not None, (
+            'max_epochs must be specified during instantiation')
+
+        for i, (mode, _) in enumerate(workflow):
+            if mode == 'train':
+                self._max_iters = self._max_epochs * len(data_loaders[i])
+                break
+
+        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('Hooks will be executed in the following order:\n%s',
+                         self.get_hook_info())
+        self.logger.info('workflow: %s, max: %d epochs', workflow,
+                         self._max_epochs)
+        self.call_hook('before_run')
+
+        while self.epoch < self._max_epochs:
+            for i, flow in enumerate(workflow):
+                mode, epochs = flow
+                if not isinstance(mode, str):
+                    raise TypeError(
+                        'mode in workflow must be a str, but got {}'.format(
+                            type(mode)))
+                if not hasattr(self, mode):
+                    raise ValueError(
+                        f'runner has no method named "{mode}" to run an '
+                        'epoch')
+                epoch_runner = getattr(self, mode)  # e.g. self.train
+
+                for _ in range(epochs):
+                    if mode == 'train' and self.epoch >= self._max_epochs:
+                        break
+                    epoch_runner(data_loaders[i], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_run')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='epoch_{}.pth',
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        """Save the checkpoint.
+
+        Args:
+            out_dir (str): The directory that checkpoints are saved.
+            filename_tmpl (str, optional): The checkpoint filename template,
+                which contains a placeholder for the epoch number.
+                Defaults to 'epoch_{}.pth'.
+            save_optimizer (bool, optional): Whether to save the optimizer to
+                the checkpoint. Defaults to True.
+            meta (dict, optional): The meta information to be saved in the
+                checkpoint; it is updated in place. Defaults to None.
+            create_symlink (bool, optional): Whether to create a symlink
+                "latest.pth" to point to the latest checkpoint.
+                Defaults to True.
+        """
+        if meta is None:
+            meta = {}
+        elif not isinstance(meta, dict):
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+            # Note: meta.update(self.meta) should be done before
+            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+            # there will be problems with resumed checkpoints.
+            # More details in https://github.com/open-mmlab/mmcv/pull/1108
+        meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+        filename = filename_tmpl.format(self.epoch + 1)
+        filepath = osp.join(out_dir, filename)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+        # in some environments, `os.symlink` is not supported, you may need to
+        # set `create_symlink` to False
+        if create_symlink:
+            dst_file = osp.join(out_dir, 'latest.pth')
+            if platform.system() != 'Windows':
+                mmcv.symlink(filename, dst_file)
+            else:
+                shutil.copy(filepath, dst_file)
+
+
+@RUNNERS.register_module()
+class Runner(EpochBasedRunner):
+    """Deprecated name of EpochBasedRunner; warns when instantiated."""
+
+    def __init__(self, *args, **kwargs):
+        message = 'Runner was deprecated, please use EpochBasedRunner instead'
+        warnings.warn(message)
+        super().__init__(*args, **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b54886519fd2808360b1632e5bebf6563eced2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import warnings
+from collections import abc
+from inspect import getfullargspec
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+from .dist_utils import allreduce_grads as _allreduce_grads
+
+try:
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
+ # manually, so the behavior may not be consistent with real amp.
+ from torch.cuda.amp import autocast
+except ImportError:
+ pass
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+    """Recursively convert Tensor in inputs from src_type to dst_type.
+
+    Only tensors whose dtype equals ``src_type`` are converted; modules,
+    strings, numpy arrays and unrecognized objects are returned as they are.
+
+    Args:
+        inputs: Inputs that to be casted.
+        src_type (torch.dtype): Source type.
+        dst_type (torch.dtype): Destination type.
+
+    Returns:
+        The same type with inputs, but all contained Tensors have been cast.
+    """
+    # Leaf types that are never converted.
+    if isinstance(inputs, (nn.Module, str, np.ndarray)):
+        return inputs
+    if isinstance(inputs, torch.Tensor):
+        # Tensors of any other dtype (e.g. integer indices) stay untouched.
+        return inputs.to(dst_type) if inputs.dtype == src_type else inputs
+    if isinstance(inputs, abc.Mapping):
+        return type(inputs)({
+            key: cast_tensor_type(value, src_type, dst_type)
+            for key, value in inputs.items()
+        })
+    if isinstance(inputs, abc.Iterable):
+        return type(inputs)(
+            cast_tensor_type(item, src_type, dst_type) for item in inputs)
+    return inputs
+
+
+def auto_fp16(apply_to=None, out_fp32=False):
+    """Decorator to enable fp16 training automatically.
+
+    This decorator is useful when you write custom modules and want to support
+    mixed precision training. The selected arguments are converted with
+    ``cast_tensor_type(arg, torch.float, torch.half)`` and extra ``*args`` are
+    forwarded as they are. If you are using PyTorch >= 1.6, torch.cuda.amp is
+    used as the backend, otherwise, original mmcv implementation is adopted.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all named arguments.
+        out_fp32 (bool): Whether to convert the output back to fp32.
+
+    Example:
+
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp16
+        >>>     @auto_fp16()
+        >>>     def forward(self, x, y):
+        >>>         pass
+
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp16
+        >>>     @auto_fp16(apply_to=('pred', ))
+        >>>     def do_something(self, pred, others):
+        >>>         pass
+    """
+
+    def auto_fp16_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            module = args[0]
+            if not isinstance(module, torch.nn.Module):
+                raise TypeError('@auto_fp16 can only be used to decorate the '
+                                'method of nn.Module')
+            # Modules without a truthy `fp16_enabled` attribute simply run the
+            # original method.
+            if not (hasattr(module, 'fp16_enabled') and module.fp16_enabled):
+                return old_func(*args, **kwargs)
+
+            spec = getfullargspec(old_func)
+            # names of the arguments whose tensors are converted
+            cast_names = spec.args if apply_to is None else apply_to
+
+            def to_half(name, value):
+                if name in cast_names:
+                    return cast_tensor_type(value, torch.float, torch.half)
+                return value
+
+            # convert the positional args that need to be processed
+            # NOTE: default args are not taken into consideration
+            named = spec.args[:len(args)]
+            new_args = [to_half(n, a) for n, a in zip(named, args)]
+            # Positional arguments beyond the named parameters (``*args`` of
+            # the decorated method) are forwarded unchanged, not dropped.
+            new_args.extend(args[len(named):])
+            # convert the kwargs that need to be processed
+            new_kwargs = {
+                name: to_half(name, value)
+                for name, value in kwargs.items()
+            }
+
+            # torch.cuda.amp.autocast is only used on PyTorch >= 1.6
+            use_amp = (TORCH_VERSION != 'parrots' and
+                       digit_version(TORCH_VERSION) >= digit_version('1.6.0'))
+            if use_amp:
+                with autocast(enabled=True):
+                    output = old_func(*new_args, **new_kwargs)
+            else:
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp32:
+                output = cast_tensor_type(output, torch.half, torch.float)
+            return output
+
+        return new_func
+
+    return auto_fp16_wrapper
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+    """Decorator to convert input arguments to fp32 in force.
+
+    This decorator is useful when you write custom modules and want to support
+    mixed precision training. If there are some inputs that must be processed
+    in fp32 mode, then this decorator can handle it. The selected arguments
+    are converted with ``cast_tensor_type(arg, torch.half, torch.float)`` and
+    extra ``*args`` are forwarded as they are. If you are using
+    PyTorch >= 1.6, the method is run with torch.cuda.amp autocast disabled,
+    otherwise it is called directly.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all named arguments.
+        out_fp16 (bool): Whether to convert the output back to fp16.
+
+    Example:
+
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp32
+        >>>     @force_fp32()
+        >>>     def loss(self, x, y):
+        >>>         pass
+
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp32
+        >>>     @force_fp32(apply_to=('pred', ))
+        >>>     def post_process(self, pred, others):
+        >>>         pass
+    """
+
+    def force_fp32_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            module = args[0]
+            if not isinstance(module, torch.nn.Module):
+                raise TypeError('@force_fp32 can only be used to decorate the '
+                                'method of nn.Module')
+            # Modules without a truthy `fp16_enabled` attribute simply run the
+            # original method.
+            if not (hasattr(module, 'fp16_enabled') and module.fp16_enabled):
+                return old_func(*args, **kwargs)
+
+            spec = getfullargspec(old_func)
+            # names of the arguments whose tensors are converted
+            cast_names = spec.args if apply_to is None else apply_to
+
+            def to_float(name, value):
+                if name in cast_names:
+                    return cast_tensor_type(value, torch.half, torch.float)
+                return value
+
+            # convert the positional args that need to be processed
+            named = spec.args[:len(args)]
+            new_args = [to_float(n, a) for n, a in zip(named, args)]
+            # Positional arguments beyond the named parameters (``*args`` of
+            # the decorated method) are forwarded unchanged, not dropped.
+            new_args.extend(args[len(named):])
+            # convert the kwargs that need to be processed
+            new_kwargs = {
+                name: to_float(name, value)
+                for name, value in kwargs.items()
+            }
+            # autocast is switched off for the call on PyTorch >= 1.6
+            use_amp = (TORCH_VERSION != 'parrots' and
+                       digit_version(TORCH_VERSION) >= digit_version('1.6.0'))
+            if use_amp:
+                with autocast(enabled=False):
+                    output = old_func(*new_args, **new_kwargs)
+            else:
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp16 if necessary
+            if out_fp16:
+                output = cast_tensor_type(output, torch.float, torch.half)
+            return output
+
+        return new_func
+
+    return force_fp32_wrapper
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    warnings.warn(  # deprecated alias: warns, then forwards the call
+        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
+        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads"')
+    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
+
+
+def wrap_fp16_model(model):
+    """Wrap the FP32 model to FP16.
+
+    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+    backend, otherwise, original mmcv implementation will be adopted.
+
+    For PyTorch >= 1.6, this function will
+    1. Set `fp16_enabled` flag inside the model to True.
+
+    Otherwise:
+    1. Convert FP32 model to FP16.
+    2. Remain some necessary layers to be FP32, e.g., normalization layers.
+    3. Set `fp16_enabled` flag inside the model to True.
+
+    Args:
+        model (nn.Module): Model in FP32.
+    """
+    legacy_torch = (TORCH_VERSION == 'parrots'
+                    or digit_version(TORCH_VERSION) < digit_version('1.6.0'))
+    if legacy_torch:
+        model.half()
+        # patch the normalization layers to make it work in fp32 mode
+        patch_norm_fp32(model)
+    # only modules that already define `fp16_enabled` get the flag set
+    for submodule in model.modules():
+        if hasattr(submodule, 'fp16_enabled'):
+            submodule.fp16_enabled = True
+
+
+def patch_norm_fp32(module):
+    """Recursively convert normalization layers from FP16 to FP32.
+
+    The module is modified in place and returned.
+
+    Args:
+        module (nn.Module): The modules to be converted in FP16.
+    """
+    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
+        module.float()
+        old_torch = (TORCH_VERSION != 'parrots' and  # parsed, not str compare
+                     digit_version(TORCH_VERSION) < digit_version('1.3'))
+        if isinstance(module, nn.GroupNorm) or old_torch:
+            module.forward = patch_forward_method(module.forward, torch.half,
+                                                  torch.float)
+    for child in module.children():
+        patch_norm_fp32(child)
+    return module
+
+
+def patch_forward_method(func, src_type, dst_type, convert_output=True):
+    """Wrap ``func`` so its arguments are cast before it is called.
+
+    Args:
+        func (callable): The original forward method.
+        src_type (torch.dtype): Type of input arguments to be converted from.
+        dst_type (torch.dtype): Type of input arguments to be converted to.
+        convert_output (bool): Whether to convert the output back to src_type.
+
+    Returns:
+        callable: The patched forward method.
+    """
+
+    def new_forward(*args, **kwargs):
+        cast_args = cast_tensor_type(args, src_type, dst_type)
+        cast_kwargs = cast_tensor_type(kwargs, src_type, dst_type)
+        output = func(*cast_args, **cast_kwargs)
+        return (cast_tensor_type(output, dst_type, src_type)
+                if convert_output else output)
+
+    return new_forward
+
+
+class LossScaler:
+    """Class that manages loss scaling in mixed precision training which
+    supports both dynamic or static mode.
+
+    The implementation refers to
+    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
+    Pass ``mode='dynamic'`` to use dynamic loss scaling.
+    It's important to understand how :class:`LossScaler` operates.
+    Loss scaling is designed to combat the problem of underflowing
+    gradients encountered at long times when training fp16 networks.
+    Dynamic loss scaling begins by attempting a very high loss
+    scale. Ironically, this may result in OVERflowing gradients.
+    If overflowing gradients are encountered, :class:`FP16_Optimizer` then
+    skips the update step for this particular iteration/minibatch,
+    and :class:`LossScaler` adjusts the loss scale to a lower value.
+    If a certain number of iterations occur without overflowing gradients
+    detected, :class:`LossScaler` increases the loss scale once more.
+    In this way :class:`LossScaler` attempts to "ride the edge" of always
+    using the highest loss scale possible without incurring overflow.
+
+    Args:
+        init_scale (float): Initial loss scale value, default: 2**32.
+        scale_factor (float): Factor used when adjusting the loss scale.
+            Default: 2.
+        mode (str): Loss scaling mode. 'dynamic' (default) or 'static'.
+        scale_window (int): Number of consecutive iterations without an
+            overflow to wait before increasing the loss scale. Default: 1000.
+    """
+
+    def __init__(self,
+                 init_scale=2**32,
+                 mode='dynamic',
+                 scale_factor=2.,
+                 scale_window=1000):
+        assert mode in ('dynamic',
+                        'static'), 'mode can only be dynamic or static'
+        self.mode = mode
+        self.cur_scale = init_scale
+        self.cur_iter = 0
+        self.last_overflow_iter = -1
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+
+    def has_overflow(self, params):
+        """Check if params contain overflow (always False if not dynamic)."""
+        if self.mode != 'dynamic':
+            return False
+        # any() stops at the first gradient that holds inf or NaN.
+        return any(
+            p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data)
+            for p in params)
+
+    @staticmethod
+    def _has_inf_or_nan(x):
+        """Check if the sum of tensor ``x`` is inf, -inf or NaN."""
+        try:
+            cpu_sum = float(x.float().sum())
+        except RuntimeError as instance:
+            # Only the "value cannot be converted" failure counts as overflow.
+            if 'value cannot be converted' not in instance.args[0]:
+                raise
+            return True
+        # ``cpu_sum != cpu_sum`` is True only for NaN.
+        inf = float('inf')
+        return cpu_sum == inf or cpu_sum == -inf or cpu_sum != cpu_sum
+
+    def update_scale(self, overflow):
+        """Shrink the scale (not below 1) on overflow, otherwise grow it once
+        every ``scale_window`` iterations since the last overflow."""
+        if self.mode != 'dynamic':
+            return
+        since_overflow = self.cur_iter - self.last_overflow_iter
+        if overflow:
+            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
+            self.last_overflow_iter = self.cur_iter
+        elif since_overflow % self.scale_window == 0:
+            self.cur_scale *= self.scale_factor
+        self.cur_iter += 1
+
+    def state_dict(self):
+        """Returns the state of the scaler as a :class:`dict`."""
+        return {
+            'cur_scale': self.cur_scale,
+            'cur_iter': self.cur_iter,
+            'mode': self.mode,
+            'last_overflow_iter': self.last_overflow_iter,
+            'scale_factor': self.scale_factor,
+            'scale_window': self.scale_window,
+        }
+
+    def load_state_dict(self, state_dict):
+        """Loads the loss_scaler state dict.
+
+        Args:
+            state_dict (dict): scaler state.
+        """
+        # Same fields, in the same order, as produced by state_dict().
+        # Every key is required; a missing one raises KeyError.
+        for key in ('cur_scale', 'cur_iter', 'mode', 'last_overflow_iter',
+                    'scale_factor', 'scale_window'):
+            setattr(self, key, state_dict[key])
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale  # the loss scale currently in effect
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .checkpoint import CheckpointHook
+from .closure import ClosureHook
+from .ema import EMAHook
+from .evaluation import DistEvalHook, EvalHook
+from .hook import HOOKS, Hook
+from .iter_timer import IterTimerHook
+from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
+ NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+ TextLoggerHook, WandbLoggerHook)
+from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
+from .momentum_updater import MomentumUpdaterHook
+from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+ GradientCumulativeOptimizerHook, OptimizerHook)
+from .profiler import ProfilerHook
+from .sampler_seed import DistSamplerSeedHook
+from .sync_buffer import SyncBuffersHook
+
+__all__ = [
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+ 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
+ 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
+ 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
+ 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
+ 'GradientCumulativeFp16OptimizerHook'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..877aa8b84ac48bea0a06f9d0733d74f88be2ecfc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+
+from annotator.mmpkg.mmcv.fileio import FileClient
+from ..dist_utils import allreduce_params, master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class CheckpointHook(Hook):
+    """Save checkpoints periodically.
+
+    Args:
+        interval (int): The saving period. If ``by_epoch=True``, interval
+            indicates epochs, otherwise it indicates iterations.
+            Default: -1, which means "never".
+        by_epoch (bool): Saving checkpoints by epoch or by iteration.
+            Default: True.
+        save_optimizer (bool): Whether to save optimizer state_dict in the
+            checkpoint. It is usually used for resuming experiments.
+            Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, ``runner.work_dir`` will be used by default. If
+            specified, the ``out_dir`` will be the concatenation of ``out_dir``
+            and the last level directory of ``runner.work_dir``.
+            `Changed in version 1.3.16.`
+        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
+            In some cases we want only the latest few checkpoints and would
+            like to delete old ones to save the disk space.
+            Default: -1, which means unlimited.
+        save_last (bool, optional): Whether to force the last checkpoint to be
+            saved regardless of interval. Default: True.
+        sync_buffer (bool, optional): Whether to synchronize buffers in
+            different gpus. Default: False.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+
+    .. warning::
+        Before v1.3.16, the ``out_dir`` argument indicates the path where the
+        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
+        root directory and the final path to save checkpoint is the
+        concatenation of ``out_dir`` and the last level directory of
+        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
+        and the value of ``runner.work_dir`` is "/path/of/B", then the final
+        path will be "/path/of/A/B".
+    """
+
+    def __init__(self,
+                 interval=-1,
+                 by_epoch=True,
+                 save_optimizer=True,
+                 out_dir=None,
+                 max_keep_ckpts=-1,
+                 save_last=True,
+                 sync_buffer=False,
+                 file_client_args=None,
+                 **kwargs):
+        self.interval = interval
+        self.by_epoch = by_epoch
+        self.save_optimizer = save_optimizer
+        self.out_dir = out_dir
+        self.max_keep_ckpts = max_keep_ckpts
+        self.save_last = save_last
+        self.args = kwargs
+        self.sync_buffer = sync_buffer
+        self.file_client_args = file_client_args
+
+    def before_run(self, runner):
+        """Resolve ``out_dir``, the file client and ``create_symlink``."""
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+
+        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
+                            f'{self.file_client.name}.'))
+
+        # disable the create_symlink option because some file backends do not
+        # allow to create a symlink
+        if 'create_symlink' in self.args:
+            wants_symlink = self.args['create_symlink']
+            if wants_symlink and not self.file_client.allow_symlink:
+                self.args['create_symlink'] = False
+                warnings.warn(
+                    ('create_symlink is set as True by the user but is changed'
+                     ' to be False because creating symbolic link is not '
+                     f'allowed in {self.file_client.name}'))
+        else:
+            self.args['create_symlink'] = self.file_client.allow_symlink
+
+    def after_train_epoch(self, runner):
+        """Save a checkpoint after an epoch when saving by epoch is due."""
+        if not self.by_epoch:
+            return
+
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` epochs
+        # 2. reach the last epoch of training
+        due = self.every_n_epochs(runner, self.interval)
+        if due or (self.save_last and self.is_last_epoch(runner)):
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
+            if self.sync_buffer:
+                allreduce_params(runner.model.buffers())
+            self._save_checkpoint(runner)
+
+    @master_only
+    def _save_checkpoint(self, runner):
+        """Save the current checkpoint and delete unwanted checkpoint."""
+        runner.save_checkpoint(
+            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+        if runner.meta is not None:
+            if self.by_epoch:
+                cur_ckpt_filename = self.args.get(
+                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+            else:
+                cur_ckpt_filename = self.args.get(
+                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            runner.meta.setdefault('hook_msgs', dict())
+            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
+                self.out_dir, cur_ckpt_filename)
+        # remove other checkpoints
+        if self.max_keep_ckpts > 0:
+            if self.by_epoch:
+                name = 'epoch_{}.pth'
+                current_ckpt = runner.epoch + 1
+            else:
+                name = 'iter_{}.pth'
+                current_ckpt = runner.iter + 1
+            redundant_ckpts = range(
+                current_ckpt - self.max_keep_ckpts * self.interval, 0,
+                -self.interval)
+            filename_tmpl = self.args.get('filename_tmpl', name)
+            for _step in redundant_ckpts:
+                ckpt_path = self.file_client.join_path(
+                    self.out_dir, filename_tmpl.format(_step))
+                if self.file_client.isfile(ckpt_path):
+                    self.file_client.remove(ckpt_path)
+                else:
+                    break
+
+    def after_train_iter(self, runner):
+        """Save a checkpoint after an iteration when saving by iter is due."""
+        if self.by_epoch:
+            return
+
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` iterations
+        # 2. reach the last iteration of training
+        due = self.every_n_iters(runner, self.interval)
+        if due or (self.save_last and self.is_last_iter(runner)):
+            runner.logger.info(
+                f'Saving checkpoint at {runner.iter + 1} iterations')
+            if self.sync_buffer:
+                allreduce_params(runner.model.buffers())
+            self._save_checkpoint(runner)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py
new file mode 100644
index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ClosureHook(Hook):
+    """Hook whose attribute ``fn_name`` is replaced by callable ``fn``."""
+    def __init__(self, fn_name, fn):
+        assert hasattr(self, fn_name)
+        assert callable(fn)
+        setattr(self, fn_name, fn)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+    r"""Exponential Moving Average Hook.
+
+    Use Exponential Moving Average on all parameters of model in training
+    process. All parameters have a ema backup, which update by the formula
+    as below. EMAHook takes priority over EvalHook and CheckpointSaverHook.
+
+    .. math::
+
+        \text{Xema\_{t+1}} = (1 - \text{momentum}) \times
+        \text{Xema\_{t}} + \text{momentum} \times X_t
+
+    Args:
+        momentum (float): The momentum used for updating ema parameter.
+            Defaults to 0.0002.
+        interval (int): Update ema parameter every interval iteration.
+            Defaults to 1.
+        warm_up (int): During first warm_up steps, we may use smaller momentum
+            to update ema parameters more slowly. Defaults to 100.
+        resume_from (str): The checkpoint path. Defaults to None.
+    """
+
+    def __init__(self,
+                 momentum=0.0002,
+                 interval=1,
+                 warm_up=100,
+                 resume_from=None):
+        assert isinstance(interval, int) and interval > 0
+        self.warm_up = warm_up
+        self.interval = interval
+        assert momentum > 0 and momentum < 1
+        self.momentum = momentum**interval
+        self.checkpoint = resume_from
+
+    def before_run(self, runner):
+        """To resume model with its ema parameters more friendly.
+
+        Register ema parameter as ``named_buffer`` to model
+        """
+        model = runner.model
+        if is_module_wrapper(model):
+            model = model.module
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
+    def after_train_iter(self, runner):
+        """Update ema parameter every self.interval iterations."""
+        curr_step = runner.iter
+        # We warm up the momentum considering the instability at beginning
+        momentum = min(self.momentum,
+                       (1 + curr_step) / (self.warm_up + curr_step))
+        if curr_step % self.interval != 0:
+            return
+        for name, parameter in self.model_parameters.items():
+            ema = self.model_buffers[self.param_ema_buffer[name]]
+            # ema = (1 - momentum) * ema + momentum * parameter
+            ema.mul_(1 - momentum).add_(parameter.data, alpha=momentum)
+
+    def after_train_epoch(self, runner):
+        """We load parameter values from ema backup to model before the
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def before_train_epoch(self, runner):
+        """We recover model's parameter from ema backup after last epoch's
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def _swap_ema_parameters(self):
+        """Swap the parameter of model with parameter in ema_buffer."""
+        for name, value in self.model_parameters.items():
+            temp = value.data.clone()
+            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+            value.data.copy_(ema_buffer.data)
+            ema_buffer.data.copy_(temp)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1dbdfd593bae505a70534226b79791baec6453e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from math import inf
+
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from annotator.mmpkg.mmcv.fileio import FileClient
+from annotator.mmpkg.mmcv.utils import is_seq_of
+from .hook import Hook
+from .logger import LoggerHook
+
+
+class EvalHook(Hook):
+ """Non-Distributed evaluation hook.
+
+ This hook will regularly perform evaluation in a given interval when
+ performing in non-distributed environment.
+
+ Args:
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+ implemented ``evaluate`` function.
+ start (int | None, optional): Evaluation starting epoch. It enables
+ evaluation before the training starts if ``start`` <= the resuming
+ epoch. If None, whether to evaluate is merely decided by
+ ``interval``. Default: None.
+ interval (int): Evaluation interval. Default: 1.
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+ If set to True, it will perform by epoch. Otherwise, by iteration.
+ Default: True.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+ best score value and best checkpoint path, which will be also
+ loaded when resume checkpoint. Options are the evaluation metrics
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+ detection and instance segmentation. ``AR@100`` for proposal
+ recall. If ``save_best`` is ``auto``, the first key of the returned
+ ``OrderedDict`` result will be used. Default: None.
+ rule (str | None, optional): Comparison rule for best score. If set to
+ None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+ .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+ be inferred by 'less' rule. Options are 'greater', 'less', None.
+ Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader, and return the test results. If ``None``, the default
+ test function ``mmcv.engine.single_gpu_test`` will be used.
+ (default: ``None``)
+ greater_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'greater' comparison rule. If ``None``,
+ _default_greater_keys will be used. (default: ``None``)
+ less_keys (List[str] | None, optional): Metric keys that will be
+ inferred by 'less' comparison rule. If ``None``, _default_less_keys
+ will be used. (default: ``None``)
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, `runner.work_dir` will be used by default. If specified,
+ the `out_dir` will be the concatenation of `out_dir` and the last
+ level directory of `runner.work_dir`.
+ `New in version 1.3.16.`
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
+ `New in version 1.3.16.`
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+
+ Notes:
+ If new arguments are added for EvalHook, tools/test.py,
+ tools/eval_metric.py may be affected.
+ """
+
+ # Since the key for determine greater or less is related to the downstream
+ # tasks, downstream repos may need to overwrite the following inner
+ # variable accordingly.
+
+ rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+ init_value_map = {'greater': -inf, 'less': inf}
+ _default_greater_keys = [
+ 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+ 'mAcc', 'aAcc'
+ ]
+ _default_less_keys = ['loss']
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
+ out_dir=None,
+ file_client_args=None,
+ **eval_kwargs):
+ if not isinstance(dataloader, DataLoader):
+ raise TypeError(f'dataloader must be a pytorch DataLoader, '
+ f'but got {type(dataloader)}')
+
+ if interval <= 0:
+ raise ValueError(f'interval must be a positive number, '
+ f'but got {interval}')
+
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+ if start is not None and start < 0:
+ raise ValueError(f'The evaluation start epoch {start} is smaller '
+ f'than 0')
+
+ self.dataloader = dataloader
+ self.interval = interval
+ self.start = start
+ self.by_epoch = by_epoch
+
+ assert isinstance(save_best, str) or save_best is None, \
+ '""save_best"" should be a str or None ' \
+ f'rather than {type(save_best)}'
+ self.save_best = save_best
+ self.eval_kwargs = eval_kwargs
+ self.initial_flag = True
+
+ if test_fn is None:
+ from annotator.mmpkg.mmcv.engine import single_gpu_test
+ self.test_fn = single_gpu_test
+ else:
+ self.test_fn = test_fn
+
+ if greater_keys is None:
+ self.greater_keys = self._default_greater_keys
+ else:
+ if not isinstance(greater_keys, (list, tuple)):
+ greater_keys = (greater_keys, )
+ assert is_seq_of(greater_keys, str)
+ self.greater_keys = greater_keys
+
+ if less_keys is None:
+ self.less_keys = self._default_less_keys
+ else:
+ if not isinstance(less_keys, (list, tuple)):
+ less_keys = (less_keys, )
+ assert is_seq_of(less_keys, str)
+ self.less_keys = less_keys
+
+ if self.save_best is not None:
+ self.best_ckpt_path = None
+ self._init_rule(rule, self.save_best)
+
+ self.out_dir = out_dir
+ self.file_client_args = file_client_args
+
+ def _init_rule(self, rule, key_indicator):
+ """Initialize rule, key_indicator, comparison_func, and best score.
+
+ Here is the rule to determine which rule is used for key indicator
+ when the rule is not specific (note that the key indicator matching
+ is case-insensitive):
+ 1. If the key indicator is in ``self.greater_keys``, the rule will be
+ specified as 'greater'.
+ 2. Or if the key indicator is in ``self.less_keys``, the rule will be
+ specified as 'less'.
+ 3. Or if the key indicator is equal to the substring in any one item
+ in ``self.greater_keys``, the rule will be specified as 'greater'.
+ 4. Or if the key indicator is equal to the substring in any one item
+ in ``self.less_keys``, the rule will be specified as 'less'.
+
+ Args:
+ rule (str | None): Comparison rule for best score.
+ key_indicator (str | None): Key indicator to determine the
+ comparison rule.
+ """
+ if rule not in self.rule_map and rule is not None:
+ raise KeyError(f'rule must be greater, less or None, '
+ f'but got {rule}.')
+
+ if rule is None:
+ if key_indicator != 'auto':
+ # `_lc` here means we use the lower case of keys for
+ # case-insensitive matching
+ key_indicator_lc = key_indicator.lower()
+ greater_keys = [key.lower() for key in self.greater_keys]
+ less_keys = [key.lower() for key in self.less_keys]
+
+ if key_indicator_lc in greater_keys:
+ rule = 'greater'
+ elif key_indicator_lc in less_keys:
+ rule = 'less'
+ elif any(key in key_indicator_lc for key in greater_keys):
+ rule = 'greater'
+ elif any(key in key_indicator_lc for key in less_keys):
+ rule = 'less'
+ else:
+ raise ValueError(f'Cannot infer the rule for key '
+ f'{key_indicator}, thus a specific rule '
+ f'must be specified.')
+ self.rule = rule
+ self.key_indicator = key_indicator
+ if self.rule is not None:
+ self.compare_func = self.rule_map[self.rule]
+
+ def before_run(self, runner):
+ if not self.out_dir:
+ self.out_dir = runner.work_dir
+
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+ # `self.out_dir` is set so the final `self.out_dir` is the
+ # concatenation of `self.out_dir` and the last level directory of
+ # `runner.work_dir`
+ if self.out_dir != runner.work_dir:
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+ runner.logger.info(
+ (f'The best checkpoint will be saved to {self.out_dir} by '
+ f'{self.file_client.name}'))
+
+ if self.save_best is not None:
+ if runner.meta is None:
+ warnings.warn('runner.meta is None. Creating an empty one.')
+ runner.meta = dict()
+ runner.meta.setdefault('hook_msgs', dict())
+ self.best_ckpt_path = runner.meta['hook_msgs'].get(
+ 'best_ckpt', None)
+
+ def before_train_iter(self, runner):
+ """Evaluate the model only at the start of training by iteration."""
+ if self.by_epoch or not self.initial_flag:
+ return
+ if self.start is not None and runner.iter >= self.start:
+ self.after_train_iter(runner)
+ self.initial_flag = False
+
+ def before_train_epoch(self, runner):
+ """Evaluate the model only at the start of training by epoch."""
+ if not (self.by_epoch and self.initial_flag):
+ return
+ if self.start is not None and runner.epoch >= self.start:
+ self.after_train_epoch(runner)
+ self.initial_flag = False
+
+ def after_train_iter(self, runner):
+ """Called after every training iter to evaluate the results."""
+ if not self.by_epoch and self._should_evaluate(runner):
+ # Because the priority of EvalHook is higher than LoggerHook, the
+ # training log and the evaluating log are mixed. Therefore,
+ # we need to dump the training log and clear it before evaluating
+ # log is generated. In addition, this problem will only appear in
+ # `IterBasedRunner` whose `self.by_epoch` is False, because
+ # `EpochBasedRunner` whose `self.by_epoch` is True calls
+ # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+ # the training log has been printed, so it will not cause any
+ # problem. more details at
+ # https://github.com/open-mmlab/mmsegmentation/issues/694
+ for hook in runner._hooks:
+ if isinstance(hook, LoggerHook):
+ hook.after_train_iter(runner)
+ runner.log_buffer.clear()
+
+ self._do_evaluate(runner)
+
+ def after_train_epoch(self, runner):
+ """Called after every training epoch to evaluate the results."""
+ if self.by_epoch and self._should_evaluate(runner):
+ self._do_evaluate(runner)
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ results = self.test_fn(runner.model, self.dataloader)
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # the key_score may be `None` so it needs to skip the action to save
+ # the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
+
+ def _should_evaluate(self, runner):
+ """Judge whether to perform evaluation.
+
+ Here is the rule to judge whether to perform evaluation:
+ 1. It will not perform evaluation during the epoch/iteration interval,
+ which is determined by ``self.interval``.
+ 2. It will not perform evaluation if the start time is larger than
+ current time.
+ 3. It will not perform evaluation when current time is larger than
+ the start time but during epoch/iteration interval.
+
+ Returns:
+ bool: The flag indicating whether to perform evaluation.
+ """
+ if self.by_epoch:
+ current = runner.epoch
+ check_time = self.every_n_epochs
+ else:
+ current = runner.iter
+ check_time = self.every_n_iters
+
+ if self.start is None:
+ if not check_time(runner, self.interval):
+ # No evaluation during the interval.
+ return False
+ elif (current + 1) < self.start:
+ # No evaluation if start is larger than the current time.
+ return False
+ else:
+ # Evaluation only at epochs/iters 3, 5, 7...
+ # if start==3 and interval==2
+ if (current + 1 - self.start) % self.interval:
+ return False
+ return True
+
+ def _save_ckpt(self, runner, key_score):
+ """Save the best checkpoint.
+
+ It will compare the score according to the compare function, write
+ related information (best score, best checkpoint path) and save the
+ best checkpoint into ``work_dir``.
+ """
+ if self.by_epoch:
+ current = f'epoch_{runner.epoch + 1}'
+ cur_type, cur_time = 'epoch', runner.epoch + 1
+ else:
+ current = f'iter_{runner.iter + 1}'
+ cur_type, cur_time = 'iter', runner.iter + 1
+
+ best_score = runner.meta['hook_msgs'].get(
+ 'best_score', self.init_value_map[self.rule])
+ if self.compare_func(key_score, best_score):
+ best_score = key_score
+ runner.meta['hook_msgs']['best_score'] = best_score
+
+ if self.best_ckpt_path and self.file_client.isfile(
+ self.best_ckpt_path):
+ self.file_client.remove(self.best_ckpt_path)
+ runner.logger.info(
+ (f'The previous best checkpoint {self.best_ckpt_path} was '
+ 'removed'))
+
+ best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+ self.best_ckpt_path = self.file_client.join_path(
+ self.out_dir, best_ckpt_name)
+ runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+
+ runner.save_checkpoint(
+ self.out_dir, best_ckpt_name, create_symlink=False)
+ runner.logger.info(
+ f'Now best checkpoint is saved as {best_ckpt_name}.')
+ runner.logger.info(
+ f'Best {self.key_indicator} is {best_score:0.4f} '
+ f'at {cur_time} {cur_type}.')
+
+ def evaluate(self, runner, results):
+ """Evaluate the results.
+
+ Args:
+ runner (:obj:`mmcv.Runner`): The underlined training runner.
+ results (list): Output results.
+ """
+ eval_res = self.dataloader.dataset.evaluate(
+ results, logger=runner.logger, **self.eval_kwargs)
+
+ for name, val in eval_res.items():
+ runner.log_buffer.output[name] = val
+ runner.log_buffer.ready = True
+
+ if self.save_best is not None:
+ # If the performance of model is pool, the `eval_res` may be an
+ # empty dict and it will raise exception when `self.save_best` is
+ # not None. More details at
+ # https://github.com/open-mmlab/mmdetection/issues/6265.
+ if not eval_res:
+ warnings.warn(
+ 'Since `eval_res` is an empty dict, the behavior to save '
+ 'the best checkpoint will be skipped in this evaluation.')
+ return None
+
+ if self.key_indicator == 'auto':
+ # infer from eval_results
+ self._init_rule(self.rule, list(eval_res.keys())[0])
+ return eval_res[self.key_indicator]
+
+ return None
+
+
+class DistEvalHook(EvalHook):
+ """Distributed evaluation hook.
+
+ This hook will regularly perform evaluation in a given interval when
+ performing in distributed environment.
+
+ Args:
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+ implemented ``evaluate`` function.
+ start (int | None, optional): Evaluation starting epoch. It enables
+ evaluation before the training starts if ``start`` <= the resuming
+ epoch. If None, whether to evaluate is merely decided by
+ ``interval``. Default: None.
+ interval (int): Evaluation interval. Default: 1.
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+ If set to True, it will perform by epoch. Otherwise, by iteration.
+ default: True.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+ best score value and best checkpoint path, which will be also
+ loaded when resume checkpoint. Options are the evaluation metrics
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+ detection and instance segmentation. ``AR@100`` for proposal
+ recall. If ``save_best`` is ``auto``, the first key of the returned
+ ``OrderedDict`` result will be used. Default: None.
+ rule (str | None, optional): Comparison rule for best score. If set to
+ None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+ .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+ be inferred by 'less' rule. Options are 'greater', 'less', None.
+ Default: None.
+ test_fn (callable, optional): test a model with samples from a
+ dataloader in a multi-gpu manner, and return the test results. If
+ ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+ will be used. (default: ``None``)
+ tmpdir (str | None): Temporary directory to save the results of all
+ processes. Default: None.
+ gpu_collect (bool): Whether to use gpu or cpu to collect results.
+ Default: False.
+ broadcast_bn_buffer (bool): Whether to broadcast the
+ buffer(running_mean and running_var) of rank 0 to other rank
+ before evaluation. Default: True.
+ out_dir (str, optional): The root directory to save checkpoints. If not
+ specified, `runner.work_dir` will be used by default. If specified,
+ the `out_dir` will be the concatenation of `out_dir` and the last
+ level directory of `runner.work_dir`.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+ """
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ test_fn=None,
+ greater_keys=None,
+ less_keys=None,
+ broadcast_bn_buffer=True,
+ tmpdir=None,
+ gpu_collect=False,
+ out_dir=None,
+ file_client_args=None,
+ **eval_kwargs):
+
+ if test_fn is None:
+ from annotator.mmpkg.mmcv.engine import multi_gpu_test
+ test_fn = multi_gpu_test
+
+ super().__init__(
+ dataloader,
+ start=start,
+ interval=interval,
+ by_epoch=by_epoch,
+ save_best=save_best,
+ rule=rule,
+ test_fn=test_fn,
+ greater_keys=greater_keys,
+ less_keys=less_keys,
+ out_dir=out_dir,
+ file_client_args=file_client_args,
+ **eval_kwargs)
+
+ self.broadcast_bn_buffer = broadcast_bn_buffer
+ self.tmpdir = tmpdir
+ self.gpu_collect = gpu_collect
+
+ def _do_evaluate(self, runner):
+ """perform evaluation and save ckpt."""
+ # Synchronization of BatchNorm's buffer (running_mean
+ # and running_var) is not supported in the DDP of pytorch,
+ # which may cause the inconsistent performance of models in
+ # different ranks, so we broadcast BatchNorm's buffers
+ # of rank 0 to other ranks to avoid this.
+ if self.broadcast_bn_buffer:
+ model = runner.model
+ for name, module in model.named_modules():
+ if isinstance(module,
+ _BatchNorm) and module.track_running_stats:
+ dist.broadcast(module.running_var, 0)
+ dist.broadcast(module.running_mean, 0)
+
+ tmpdir = self.tmpdir
+ if tmpdir is None:
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+ results = self.test_fn(
+ runner.model,
+ self.dataloader,
+ tmpdir=tmpdir,
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+ key_score = self.evaluate(runner, results)
+ # the key_score may be `None` so it needs to skip the action to
+ # save the best checkpoint
+ if self.save_best and key_score:
+ self._save_ckpt(runner, key_score)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd31f985fee739ccb7ac62eefc6cef9f0c0d65d0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.mmpkg.mmcv.utils import Registry, is_method_overridden
+
+HOOKS = Registry('hook')
+
+
+class Hook:
+ stages = ('before_run', 'before_train_epoch', 'before_train_iter',
+ 'after_train_iter', 'after_train_epoch', 'before_val_epoch',
+ 'before_val_iter', 'after_val_iter', 'after_val_epoch',
+ 'after_run')
+
+ def before_run(self, runner):
+ pass
+
+ def after_run(self, runner):
+ pass
+
+ def before_epoch(self, runner):
+ pass
+
+ def after_epoch(self, runner):
+ pass
+
+ def before_iter(self, runner):
+ pass
+
+ def after_iter(self, runner):
+ pass
+
+ def before_train_epoch(self, runner):
+ self.before_epoch(runner)
+
+ def before_val_epoch(self, runner):
+ self.before_epoch(runner)
+
+ def after_train_epoch(self, runner):
+ self.after_epoch(runner)
+
+ def after_val_epoch(self, runner):
+ self.after_epoch(runner)
+
+ def before_train_iter(self, runner):
+ self.before_iter(runner)
+
+ def before_val_iter(self, runner):
+ self.before_iter(runner)
+
+ def after_train_iter(self, runner):
+ self.after_iter(runner)
+
+ def after_val_iter(self, runner):
+ self.after_iter(runner)
+
+ def every_n_epochs(self, runner, n):
+ return (runner.epoch + 1) % n == 0 if n > 0 else False
+
+ def every_n_inner_iters(self, runner, n):
+ return (runner.inner_iter + 1) % n == 0 if n > 0 else False
+
+ def every_n_iters(self, runner, n):
+ return (runner.iter + 1) % n == 0 if n > 0 else False
+
+ def end_of_epoch(self, runner):
+ return runner.inner_iter + 1 == len(runner.data_loader)
+
+ def is_last_epoch(self, runner):
+ return runner.epoch + 1 == runner._max_epochs
+
+ def is_last_iter(self, runner):
+ return runner.iter + 1 == runner._max_iters
+
+ def get_triggered_stages(self):
+ trigger_stages = set()
+ for stage in Hook.stages:
+ if is_method_overridden(stage, Hook, self):
+ trigger_stages.add(stage)
+
+ # some methods will be triggered in multi stages
+ # use this dict to map method to stages.
+ method_stages_map = {
+ 'before_epoch': ['before_train_epoch', 'before_val_epoch'],
+ 'after_epoch': ['after_train_epoch', 'after_val_epoch'],
+ 'before_iter': ['before_train_iter', 'before_val_iter'],
+ 'after_iter': ['after_train_iter', 'after_val_iter'],
+ }
+
+ for method, map_stages in method_stages_map.items():
+ if is_method_overridden(method, Hook, self):
+ trigger_stages.update(map_stages)
+
+ return [stage for stage in Hook.stages if stage in trigger_stages]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class IterTimerHook(Hook):
+
+ def before_epoch(self, runner):
+ self.t = time.time()
+
+ def before_iter(self, runner):
+ runner.log_buffer.update({'data_time': time.time() - self.t})
+
+ def after_iter(self, runner):
+ runner.log_buffer.update({'time': time.time() - self.t})
+ self.t = time.time()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import LoggerHook
+from .dvclive import DvcliveLoggerHook
+from .mlflow import MlflowLoggerHook
+from .neptune import NeptuneLoggerHook
+from .pavi import PaviLoggerHook
+from .tensorboard import TensorboardLoggerHook
+from .text import TextLoggerHook
+from .wandb import WandbLoggerHook
+
+__all__ = [
+ 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
+ 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
+ 'NeptuneLoggerHook', 'DvcliveLoggerHook'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+import torch
+
+from ..hook import Hook
+
+
+class LoggerHook(Hook):
+ """Base class for logger hooks.
+
+ Args:
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ self.interval = interval
+ self.ignore_last = ignore_last
+ self.reset_flag = reset_flag
+ self.by_epoch = by_epoch
+
+ @abstractmethod
+ def log(self, runner):
+ pass
+
+ @staticmethod
+ def is_scalar(val, include_np=True, include_torch=True):
+ """Tell the input variable is a scalar or not.
+
+ Args:
+ val: Input variable.
+ include_np (bool): Whether include 0-d np.ndarray as a scalar.
+ include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
+
+ Returns:
+ bool: True or False.
+ """
+ if isinstance(val, numbers.Number):
+ return True
+ elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
+ return True
+ elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
+ return True
+ else:
+ return False
+
+ def get_mode(self, runner):
+ if runner.mode == 'train':
+ if 'time' in runner.log_buffer.output:
+ mode = 'train'
+ else:
+ mode = 'val'
+ elif runner.mode == 'val':
+ mode = 'val'
+ else:
+ raise ValueError(f"runner mode should be 'train' or 'val', "
+ f'but got {runner.mode}')
+ return mode
+
+ def get_epoch(self, runner):
+ if runner.mode == 'train':
+ epoch = runner.epoch + 1
+ elif runner.mode == 'val':
+ # normal val mode
+ # runner.epoch += 1 has been done before val workflow
+ epoch = runner.epoch
+ else:
+ raise ValueError(f"runner mode should be 'train' or 'val', "
+ f'but got {runner.mode}')
+ return epoch
+
+ def get_iter(self, runner, inner_iter=False):
+ """Get the current training iteration step."""
+ if self.by_epoch and inner_iter:
+ current_iter = runner.inner_iter + 1
+ else:
+ current_iter = runner.iter + 1
+ return current_iter
+
+ def get_lr_tags(self, runner):
+ tags = {}
+ lrs = runner.current_lr()
+ if isinstance(lrs, dict):
+ for name, value in lrs.items():
+ tags[f'learning_rate/{name}'] = value[0]
+ else:
+ tags['learning_rate'] = lrs[0]
+ return tags
+
+ def get_momentum_tags(self, runner):
+ tags = {}
+ momentums = runner.current_momentum()
+ if isinstance(momentums, dict):
+ for name, value in momentums.items():
+ tags[f'momentum/{name}'] = value[0]
+ else:
+ tags['momentum'] = momentums[0]
+ return tags
+
+ def get_loggable_tags(self,
+ runner,
+ allow_scalar=True,
+ allow_text=False,
+ add_mode=True,
+ tags_to_skip=('time', 'data_time')):
+ tags = {}
+ for var, val in runner.log_buffer.output.items():
+ if var in tags_to_skip:
+ continue
+ if self.is_scalar(val) and not allow_scalar:
+ continue
+ if isinstance(val, str) and not allow_text:
+ continue
+ if add_mode:
+ var = f'{self.get_mode(runner)}/{var}'
+ tags[var] = val
+ tags.update(self.get_lr_tags(runner))
+ tags.update(self.get_momentum_tags(runner))
+ return tags
+
+ def before_run(self, runner):
+ for hook in runner.hooks[::-1]:
+ if isinstance(hook, LoggerHook):
+ hook.reset_flag = True
+ break
+
+ def before_epoch(self, runner):
+ runner.log_buffer.clear() # clear logs of last epoch
+
+ def after_train_iter(self, runner):
+ if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+ runner.log_buffer.average(self.interval)
+ elif not self.by_epoch and self.every_n_iters(runner, self.interval):
+ runner.log_buffer.average(self.interval)
+ elif self.end_of_epoch(runner) and not self.ignore_last:
+ # not precise but more stable
+ runner.log_buffer.average(self.interval)
+
+ if runner.log_buffer.ready:
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
+
+ def after_train_epoch(self, runner):
+ if runner.log_buffer.ready:
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
+
+ def after_val_epoch(self, runner):
+ runner.log_buffer.average()
+ self.log(runner)
+ if self.reset_flag:
+ runner.log_buffer.clear_output()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py
new file mode 100644
index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class DvcliveLoggerHook(LoggerHook):
+ """Class to log metrics with dvclive.
+
+ It requires `dvclive`_ to be installed.
+
+ Args:
+ path (str): Directory where dvclive will write TSV log files.
+ interval (int): Logging interval (every k iterations).
+ Default 10.
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ Default: True.
+ reset_flag (bool): Whether to clear the output buffer after logging.
+ Default: True.
+ by_epoch (bool): Whether EpochBasedRunner is used.
+ Default: True.
+
+ .. _dvclive:
+ https://dvc.org/doc/dvclive
+ """
+
+ def __init__(self,
+ path,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ by_epoch=True):
+
+ super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.path = path
+ self.import_dvclive()
+
+ def import_dvclive(self):
+ try:
+ import dvclive
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install dvclive" to install dvclive')
+ self.dvclive = dvclive
+
+ @master_only
+ def before_run(self, runner):
+ self.dvclive.init(self.path)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for k, v in tags.items():
+ self.dvclive.log(k, v, step=self.get_iter(runner))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class MlflowLoggerHook(LoggerHook):
+
+ def __init__(self,
+ exp_name=None,
+ tags=None,
+ log_model=True,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ """Class to log metrics and (optionally) a trained model to MLflow.
+
+ It requires `MLflow`_ to be installed.
+
+ Args:
+ exp_name (str, optional): Name of the experiment to be used.
+ Default None.
+ If not None, set the active experiment.
+ If experiment does not exist, an experiment with provided name
+ will be created.
+ tags (dict of str: str, optional): Tags for the current run.
+ Default None.
+ If not None, set tags for the current run.
+ log_model (bool, optional): Whether to log an MLflow artifact.
+ Default True.
+ If True, log runner.model as an MLflow artifact
+ for the current run.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _MLflow:
+ https://www.mlflow.org/docs/latest/index.html
+ """
+ super(MlflowLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_mlflow()
+ self.exp_name = exp_name
+ self.tags = tags
+ self.log_model = log_model
+
+ def import_mlflow(self):
+ try:
+ import mlflow
+ import mlflow.pytorch as mlflow_pytorch
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install mlflow" to install mlflow')
+ self.mlflow = mlflow
+ self.mlflow_pytorch = mlflow_pytorch
+
+ @master_only
+ def before_run(self, runner):
+ super(MlflowLoggerHook, self).before_run(runner)
+ if self.exp_name is not None:
+ self.mlflow.set_experiment(self.exp_name)
+ if self.tags is not None:
+ self.mlflow.set_tags(self.tags)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ self.mlflow.log_metrics(tags, step=self.get_iter(runner))
+
+ @master_only
+ def after_run(self, runner):
+ if self.log_model:
+ self.mlflow_pytorch.log_model(runner.model, 'models')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class NeptuneLoggerHook(LoggerHook):
+ """Class to log metrics to NeptuneAI.
+
+ It requires `neptune-client` to be installed.
+
+ Args:
+ init_kwargs (dict): a dict contains the initialization keys as below:
+ - project (str): Name of a project in a form of
+ namespace/project_name. If None, the value of
+ NEPTUNE_PROJECT environment variable will be taken.
+ - api_token (str): User’s API token.
+ If None, the value of NEPTUNE_API_TOKEN environment
+ variable will be taken. Note: It is strongly recommended
+ to use NEPTUNE_API_TOKEN environment variable rather than
+ placing your API token in plain text in your source code.
+ - name (str, optional, default is 'Untitled'): Editable name of
+ the run. Name is displayed in the run's Details and in
+ Runs table as a column.
+ Check https://docs.neptune.ai/api-reference/neptune#init for
+ more init arguments.
+ interval (int): Logging interval (every k iterations).
+ ignore_last (bool): Ignore the log of last iterations in each epoch
+ if less than `interval`.
+ reset_flag (bool): Whether to clear the output buffer after logging
+ by_epoch (bool): Whether EpochBasedRunner is used.
+
+ .. _NeptuneAI:
+ https://docs.neptune.ai/you-should-know/logging-metadata
+ """
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ with_step=True,
+ by_epoch=True):
+
+ super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_neptune()
+ self.init_kwargs = init_kwargs
+ self.with_step = with_step
+
+ def import_neptune(self):
+ try:
+ import neptune.new as neptune
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install neptune-client" to install neptune')
+ self.neptune = neptune
+ self.run = None
+
+ @master_only
+ def before_run(self, runner):
+ if self.init_kwargs:
+ self.run = self.neptune.init(**self.init_kwargs)
+ else:
+ self.run = self.neptune.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ for tag_name, tag_value in tags.items():
+ if self.with_step:
+ self.run[tag_name].log(
+ tag_value, step=self.get_iter(runner))
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.run[tag_name].log(tags)
+
+ @master_only
+ def after_run(self, runner):
+ self.run.stop()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d1c4286920361e6b80f135b8d60b250f98f507a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+
+import torch
+import yaml
+
+import annotator.mmpkg.mmcv as mmcv
+from ....parallel.utils import is_module_wrapper
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class PaviLoggerHook(LoggerHook):
+
+ def __init__(self,
+ init_kwargs=None,
+ add_graph=False,
+ add_last_ckpt=False,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True,
+ img_key='img_info'):
+ super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+ by_epoch)
+ self.init_kwargs = init_kwargs
+ self.add_graph = add_graph
+ self.add_last_ckpt = add_last_ckpt
+ self.img_key = img_key
+
+ @master_only
+ def before_run(self, runner):
+ super(PaviLoggerHook, self).before_run(runner)
+ try:
+ from pavi import SummaryWriter
+ except ImportError:
+ raise ImportError('Please run "pip install pavi" to install pavi.')
+
+ self.run_name = runner.work_dir.split('/')[-1]
+
+ if not self.init_kwargs:
+ self.init_kwargs = dict()
+ self.init_kwargs['name'] = self.run_name
+ self.init_kwargs['model'] = runner._model_name
+ if runner.meta is not None:
+ if 'config_dict' in runner.meta:
+ config_dict = runner.meta['config_dict']
+ assert isinstance(
+ config_dict,
+ dict), ('meta["config_dict"] has to be of a dict, '
+ f'but got {type(config_dict)}')
+ elif 'config_file' in runner.meta:
+ config_file = runner.meta['config_file']
+ config_dict = dict(mmcv.Config.fromfile(config_file))
+ else:
+ config_dict = None
+ if config_dict is not None:
+ # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
+ # to properly set up the progress bar.
+ config_dict = config_dict.copy()
+ config_dict.setdefault('max_iter', runner.max_iters)
+ # non-serializable values are first converted in
+ # mmcv.dump to json
+ config_dict = json.loads(
+ mmcv.dump(config_dict, file_format='json'))
+ session_text = yaml.dump(config_dict)
+ self.init_kwargs['session_text'] = session_text
+ self.writer = SummaryWriter(**self.init_kwargs)
+
+ def get_step(self, runner):
+ """Get the total training step/epoch."""
+ if self.get_mode(runner) == 'val' and self.by_epoch:
+ return self.get_epoch(runner)
+ else:
+ return self.get_iter(runner)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner, add_mode=False)
+ if tags:
+ self.writer.add_scalars(
+ self.get_mode(runner), tags, self.get_step(runner))
+
+ @master_only
+ def after_run(self, runner):
+ if self.add_last_ckpt:
+ ckpt_path = osp.join(runner.work_dir, 'latest.pth')
+ if osp.islink(ckpt_path):
+ ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
+
+ if osp.isfile(ckpt_path):
+ # runner.epoch += 1 has been done before `after_run`.
+ iteration = runner.epoch if self.by_epoch else runner.iter
+ return self.writer.add_snapshot_file(
+ tag=self.run_name,
+ snapshot_file_path=ckpt_path,
+ iteration=iteration)
+
+ # flush the buffer and send a task ending signal to Pavi
+ self.writer.close()
+
+ @master_only
+ def before_epoch(self, runner):
+ if runner.epoch == 0 and self.add_graph:
+ if is_module_wrapper(runner.model):
+ _model = runner.model.module
+ else:
+ _model = runner.model
+ device = next(_model.parameters()).device
+ data = next(iter(runner.data_loader))
+ image = data[self.img_key][0:1].to(device)
+ with torch.no_grad():
+ self.writer.add_graph(_model, image)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c480a560e90f5b06abb4afaf9597aaf7c1eaa82
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TensorboardLoggerHook(LoggerHook):
+
+ def __init__(self,
+ log_dir=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ by_epoch=True):
+ super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.log_dir = log_dir
+
+ @master_only
+ def before_run(self, runner):
+ super(TensorboardLoggerHook, self).before_run(runner)
+ if (TORCH_VERSION == 'parrots'
+ or digit_version(TORCH_VERSION) < digit_version('1.1')):
+ try:
+ from tensorboardX import SummaryWriter
+ except ImportError:
+ raise ImportError('Please install tensorboardX to use '
+ 'TensorboardLoggerHook.')
+ else:
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install future tensorboard" to install '
+ 'the dependencies to use torch.utils.tensorboard '
+ '(applicable to PyTorch 1.1 or higher)')
+
+ if self.log_dir is None:
+ self.log_dir = osp.join(runner.work_dir, 'tf_logs')
+ self.writer = SummaryWriter(self.log_dir)
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner, allow_text=True)
+ for tag, val in tags.items():
+ if isinstance(val, str):
+ self.writer.add_text(tag, val, self.get_iter(runner))
+ else:
+ self.writer.add_scalar(tag, val, self.get_iter(runner))
+
+ @master_only
+ def after_run(self, runner):
+ self.writer.close()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b30577469d5f70e544e1ce73816326e38dadb20
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+
+import annotator.mmpkg.mmcv as mmcv
+from annotator.mmpkg.mmcv.fileio.file_client import FileClient
+from annotator.mmpkg.mmcv.utils import is_tuple_of, scandir
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TextLoggerHook(LoggerHook):
+ """Logger hook in text.
+
+ In this logger hook, the information will be printed on terminal and
+ saved in json file.
+
+ Args:
+ by_epoch (bool, optional): Whether EpochBasedRunner is used.
+ Default: True.
+ interval (int, optional): Logging interval (every k iterations).
+ Default: 10.
+ ignore_last (bool, optional): Ignore the log of last iterations in each
+ epoch if less than :attr:`interval`. Default: True.
+ reset_flag (bool, optional): Whether to clear the output buffer after
+ logging. Default: False.
+ interval_exp_name (int, optional): Logging interval for experiment
+ name. This feature is to help users conveniently get the experiment
+ information from screen or log file. Default: 1000.
+ out_dir (str, optional): Logs are saved in ``runner.work_dir`` default.
+ If ``out_dir`` is specified, logs will be copied to a new directory
+ which is the concatenation of ``out_dir`` and the last level
+ directory of ``runner.work_dir``. Default: None.
+ `New in version 1.3.16.`
+ out_suffix (str or tuple[str], optional): Those filenames ending with
+ ``out_suffix`` will be copied to ``out_dir``.
+ Default: ('.log.json', '.log', '.py').
+ `New in version 1.3.16.`
+ keep_local (bool, optional): Whether to keep local log when
+ :attr:`out_dir` is specified. If False, the local log will be
+ removed. Default: True.
+ `New in version 1.3.16.`
+ file_client_args (dict, optional): Arguments to instantiate a
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
+ Default: None.
+ `New in version 1.3.16.`
+ """
+
+ def __init__(self,
+ by_epoch=True,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ interval_exp_name=1000,
+ out_dir=None,
+ out_suffix=('.log.json', '.log', '.py'),
+ keep_local=True,
+ file_client_args=None):
+ super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+ by_epoch)
+ self.by_epoch = by_epoch
+ self.time_sec_tot = 0
+ self.interval_exp_name = interval_exp_name
+
+ if out_dir is None and file_client_args is not None:
+ raise ValueError(
+ 'file_client_args should be "None" when `out_dir` is not'
+ 'specified.')
+ self.out_dir = out_dir
+
+ if not (out_dir is None or isinstance(out_dir, str)
+ or is_tuple_of(out_dir, str)):
+ raise TypeError('out_dir should be "None" or string or tuple of '
+ 'string, but got {out_dir}')
+ self.out_suffix = out_suffix
+
+ self.keep_local = keep_local
+ self.file_client_args = file_client_args
+ if self.out_dir is not None:
+ self.file_client = FileClient.infer_client(file_client_args,
+ self.out_dir)
+
+ def before_run(self, runner):
+ super(TextLoggerHook, self).before_run(runner)
+
+ if self.out_dir is not None:
+ self.file_client = FileClient.infer_client(self.file_client_args,
+ self.out_dir)
+ # The final `self.out_dir` is the concatenation of `self.out_dir`
+ # and the last level directory of `runner.work_dir`
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
+ runner.logger.info(
+ (f'Text logs will be saved to {self.out_dir} by '
+ f'{self.file_client.name} after the training process.'))
+
+ self.start_iter = runner.iter
+ self.json_log_path = osp.join(runner.work_dir,
+ f'{runner.timestamp}.log.json')
+ if runner.meta is not None:
+ self._dump_log(runner.meta, runner)
+
+ def _get_max_memory(self, runner):
+ device = getattr(runner.model, 'output_device', None)
+ mem = torch.cuda.max_memory_allocated(device=device)
+ mem_mb = torch.tensor([mem / (1024 * 1024)],
+ dtype=torch.int,
+ device=device)
+ if runner.world_size > 1:
+ dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+ return mem_mb.item()
+
+ def _log_info(self, log_dict, runner):
+ # print exp name for users to distinguish experiments
+ # at every ``interval_exp_name`` iterations and the end of each epoch
+ if runner.meta is not None and 'exp_name' in runner.meta:
+ if (self.every_n_iters(runner, self.interval_exp_name)) or (
+ self.by_epoch and self.end_of_epoch(runner)):
+ exp_info = f'Exp name: {runner.meta["exp_name"]}'
+ runner.logger.info(exp_info)
+
+ if log_dict['mode'] == 'train':
+ if isinstance(log_dict['lr'], dict):
+ lr_str = []
+ for k, val in log_dict['lr'].items():
+ lr_str.append(f'lr_{k}: {val:.3e}')
+ lr_str = ' '.join(lr_str)
+ else:
+ lr_str = f'lr: {log_dict["lr"]:.3e}'
+
+ # by epoch: Epoch [4][100/1000]
+ # by iter: Iter [100/100000]
+ if self.by_epoch:
+ log_str = f'Epoch [{log_dict["epoch"]}]' \
+ f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
+ else:
+ log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
+ log_str += f'{lr_str}, '
+
+ if 'time' in log_dict.keys():
+ self.time_sec_tot += (log_dict['time'] * self.interval)
+ time_sec_avg = self.time_sec_tot / (
+ runner.iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ log_str += f'eta: {eta_str}, '
+ log_str += f'time: {log_dict["time"]:.3f}, ' \
+ f'data_time: {log_dict["data_time"]:.3f}, '
+ # statistic memory
+ if torch.cuda.is_available():
+ log_str += f'memory: {log_dict["memory"]}, '
+ else:
+ # val/test time
+ # here 1000 is the length of the val dataloader
+ # by epoch: Epoch[val] [4][1000]
+ # by iter: Iter[val] [1000]
+ if self.by_epoch:
+ log_str = f'Epoch({log_dict["mode"]}) ' \
+ f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
+ else:
+ log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
+
+ log_items = []
+ for name, val in log_dict.items():
+ # TODO: resolve this hack
+ # these items have been in log_str
+ if name in [
+ 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
+ 'memory', 'epoch'
+ ]:
+ continue
+ if isinstance(val, float):
+ val = f'{val:.4f}'
+ log_items.append(f'{name}: {val}')
+ log_str += ', '.join(log_items)
+
+ runner.logger.info(log_str)
+
+ def _dump_log(self, log_dict, runner):
+ # dump log in json format
+ json_log = OrderedDict()
+ for k, v in log_dict.items():
+ json_log[k] = self._round_float(v)
+ # only append log at last line
+ if runner.rank == 0:
+ with open(self.json_log_path, 'a+') as f:
+ mmcv.dump(json_log, f, file_format='json')
+ f.write('\n')
+
+ def _round_float(self, items):
+ if isinstance(items, list):
+ return [self._round_float(item) for item in items]
+ elif isinstance(items, float):
+ return round(items, 5)
+ else:
+ return items
+
+ def log(self, runner):
+ if 'eval_iter_num' in runner.log_buffer.output:
+ # this doesn't modify runner.iter and is regardless of by_epoch
+ cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+ else:
+ cur_iter = self.get_iter(runner, inner_iter=True)
+
+ log_dict = OrderedDict(
+ mode=self.get_mode(runner),
+ epoch=self.get_epoch(runner),
+ iter=cur_iter)
+
+ # only record lr of the first param group
+ cur_lr = runner.current_lr()
+ if isinstance(cur_lr, list):
+ log_dict['lr'] = cur_lr[0]
+ else:
+ assert isinstance(cur_lr, dict)
+ log_dict['lr'] = {}
+ for k, lr_ in cur_lr.items():
+ assert isinstance(lr_, list)
+ log_dict['lr'].update({k: lr_[0]})
+
+ if 'time' in runner.log_buffer.output:
+ # statistic memory
+ if torch.cuda.is_available():
+ log_dict['memory'] = self._get_max_memory(runner)
+
+ log_dict = dict(log_dict, **runner.log_buffer.output)
+
+ self._log_info(log_dict, runner)
+ self._dump_log(log_dict, runner)
+ return log_dict
+
+ def after_run(self, runner):
+ # copy or upload logs to self.out_dir
+ if self.out_dir is not None:
+ for filename in scandir(runner.work_dir, self.out_suffix, True):
+ local_filepath = osp.join(runner.work_dir, filename)
+ out_filepath = self.file_client.join_path(
+ self.out_dir, filename)
+ with open(local_filepath, 'r') as f:
+ self.file_client.put_text(f.read(), out_filepath)
+
+ runner.logger.info(
+ (f'The file {local_filepath} has been uploaded to '
+ f'{out_filepath}.'))
+
+ if not self.keep_local:
+ os.remove(local_filepath)
+ runner.logger.info(
+ (f'{local_filepath} was removed due to the '
+ '`self.keep_local=False`'))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class WandbLoggerHook(LoggerHook):
+
+ def __init__(self,
+ init_kwargs=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=False,
+ commit=True,
+ by_epoch=True,
+ with_step=True):
+ super(WandbLoggerHook, self).__init__(interval, ignore_last,
+ reset_flag, by_epoch)
+ self.import_wandb()
+ self.init_kwargs = init_kwargs
+ self.commit = commit
+ self.with_step = with_step
+
+ def import_wandb(self):
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install wandb" to install wandb')
+ self.wandb = wandb
+
+ @master_only
+ def before_run(self, runner):
+ super(WandbLoggerHook, self).before_run(runner)
+ if self.wandb is None:
+ self.import_wandb()
+ if self.init_kwargs:
+ self.wandb.init(**self.init_kwargs)
+ else:
+ self.wandb.init()
+
+ @master_only
+ def log(self, runner):
+ tags = self.get_loggable_tags(runner)
+ if tags:
+ if self.with_step:
+ self.wandb.log(
+ tags, step=self.get_iter(runner), commit=self.commit)
+ else:
+ tags['global_step'] = self.get_iter(runner)
+ self.wandb.log(tags, commit=self.commit)
+
+ @master_only
+ def after_run(self, runner):
+ self.wandb.join()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9851d2ca3c4e60b95ad734c19a2484b9ca7c708
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py
@@ -0,0 +1,670 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from math import cos, pi
+
+import annotator.mmpkg.mmcv as mmcv
+from .hook import HOOKS, Hook
+
+
+class LrUpdaterHook(Hook):
+ """LR Scheduler in MMCV.
+
+ Args:
+ by_epoch (bool): LR changes epoch by epoch
+ warmup (string): Type of warmup used. It can be None(use no warmup),
+ 'constant', 'linear' or 'exp'
+ warmup_iters (int): The number of iterations or epochs that warmup
+ lasts
+ warmup_ratio (float): LR used at the beginning of warmup equals to
+ warmup_ratio * initial_lr
+ warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
+ means the number of epochs that warmup lasts, otherwise means the
+ number of iteration that warmup lasts
+ """
+
+ def __init__(self,
+ by_epoch=True,
+ warmup=None,
+ warmup_iters=0,
+ warmup_ratio=0.1,
+ warmup_by_epoch=False):
+ # validate the "warmup" argument
+ if warmup is not None:
+ if warmup not in ['constant', 'linear', 'exp']:
+ raise ValueError(
+ f'"{warmup}" is not a supported type for warming up, valid'
+ ' types are "constant" and "linear"')
+ if warmup is not None:
+ assert warmup_iters > 0, \
+ '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, \
+ '"warmup_ratio" must be in range (0,1]'
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_iters = warmup_iters
+ self.warmup_ratio = warmup_ratio
+ self.warmup_by_epoch = warmup_by_epoch
+
+ if self.warmup_by_epoch:
+ self.warmup_epochs = self.warmup_iters
+ self.warmup_iters = None
+ else:
+ self.warmup_epochs = None
+
+ self.base_lr = [] # initial lr for all param groups
+ self.regular_lr = [] # expected lr if no warming up is performed
+
+ def _set_lr(self, runner, lr_groups):
+ if isinstance(runner.optimizer, dict):
+ for k, optim in runner.optimizer.items():
+ for param_group, lr in zip(optim.param_groups, lr_groups[k]):
+ param_group['lr'] = lr
+ else:
+ for param_group, lr in zip(runner.optimizer.param_groups,
+ lr_groups):
+ param_group['lr'] = lr
+
+ def get_lr(self, runner, base_lr):
+ raise NotImplementedError
+
+ def get_regular_lr(self, runner):
+ if isinstance(runner.optimizer, dict):
+ lr_groups = {}
+ for k in runner.optimizer.keys():
+ _lr_group = [
+ self.get_lr(runner, _base_lr)
+ for _base_lr in self.base_lr[k]
+ ]
+ lr_groups.update({k: _lr_group})
+
+ return lr_groups
+ else:
+ return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
+
+ def get_warmup_lr(self, cur_iters):
+
+ def _get_warmup_lr(cur_iters, regular_lr):
+ if self.warmup == 'constant':
+ warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+ elif self.warmup == 'linear':
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
+ self.warmup_ratio)
+ warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+ elif self.warmup == 'exp':
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+ warmup_lr = [_lr * k for _lr in regular_lr]
+ return warmup_lr
+
+ if isinstance(self.regular_lr, dict):
+ lr_groups = {}
+ for key, regular_lr in self.regular_lr.items():
+ lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+ return lr_groups
+ else:
+ return _get_warmup_lr(cur_iters, self.regular_lr)
+
+ def before_run(self, runner):
+ # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
+ # it will be set according to the optimizer params
+ if isinstance(runner.optimizer, dict):
+ self.base_lr = {}
+ for k, optim in runner.optimizer.items():
+ for group in optim.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ _base_lr = [
+ group['initial_lr'] for group in optim.param_groups
+ ]
+ self.base_lr.update({k: _base_lr})
+ else:
+ for group in runner.optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ self.base_lr = [
+ group['initial_lr'] for group in runner.optimizer.param_groups
+ ]
+
+ def before_train_epoch(self, runner):
+ if self.warmup_iters is None:
+ epoch_len = len(runner.data_loader)
+ self.warmup_iters = self.warmup_epochs * epoch_len
+
+ if not self.by_epoch:
+ return
+
+ self.regular_lr = self.get_regular_lr(runner)
+ self._set_lr(runner, self.regular_lr)
+
+ def before_train_iter(self, runner):
+ cur_iter = runner.iter
+ if not self.by_epoch:
+ self.regular_lr = self.get_regular_lr(runner)
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ self._set_lr(runner, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(runner, warmup_lr)
+ elif self.by_epoch:
+ if self.warmup is None or cur_iter > self.warmup_iters:
+ return
+ elif cur_iter == self.warmup_iters:
+ self._set_lr(runner, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(runner, warmup_lr)
+
+
+@HOOKS.register_module()
+class FixedLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, **kwargs):
+ super(FixedLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ return base_lr
+
+
+@HOOKS.register_module()
+class StepLrUpdaterHook(LrUpdaterHook):
+ """Step LR scheduler with min_lr clipping.
+
+ Args:
+ step (int | list[int]): Step to decay the LR. If an int value is given,
+ regard it as the decay interval. If a list is given, decay LR at
+ these steps.
+ gamma (float, optional): Decay LR ratio. Default: 0.1.
+ min_lr (float, optional): Minimum LR value to keep. If LR after decay
+ is lower than `min_lr`, it will be clipped to this value. If None
+ is given, we don't perform lr clipping. Default: None.
+ """
+
+ def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
+ if isinstance(step, list):
+ assert mmcv.is_list_of(step, int)
+ assert all([s > 0 for s in step])
+ elif isinstance(step, int):
+ assert step > 0
+ else:
+ raise TypeError('"step" must be a list or integer')
+ self.step = step
+ self.gamma = gamma
+ self.min_lr = min_lr
+ super(StepLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+
+ # calculate exponential term
+ if isinstance(self.step, int):
+ exp = progress // self.step
+ else:
+ exp = len(self.step)
+ for i, s in enumerate(self.step):
+ if progress < s:
+ exp = i
+ break
+
+ lr = base_lr * (self.gamma**exp)
+ if self.min_lr is not None:
+ # clip to a minimum value
+ lr = max(lr, self.min_lr)
+ return lr
+
+
+@HOOKS.register_module()
+class ExpLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, gamma, **kwargs):
+ self.gamma = gamma
+ super(ExpLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+ return base_lr * self.gamma**progress
+
+
+@HOOKS.register_module()
+class PolyLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, power=1., min_lr=0., **kwargs):
+ self.power = power
+ self.min_lr = min_lr
+ super(PolyLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+ coeff = (1 - progress / max_progress)**self.power
+ return (base_lr - self.min_lr) * coeff + self.min_lr
+
+
+@HOOKS.register_module()
+class InvLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, gamma, power=1., **kwargs):
+ self.gamma = gamma
+ self.power = power
+ super(InvLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else runner.iter
+ return base_lr * (1 + self.gamma * progress)**(-self.power)
+
+
+@HOOKS.register_module()
+class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
+
+ def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ max_progress = runner.max_epochs
+ else:
+ progress = runner.iter
+ max_progress = runner.max_iters
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
+ """Flat + Cosine lr schedule.
+
+ Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
+
+ Args:
+ start_percent (float): When to start annealing the learning rate
+ after the percentage of the total training steps.
+ The value should be in range [0, 1).
+ Default: 0.75
+ min_lr (float, optional): The minimum lr. Default: None.
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+ Either `min_lr` or `min_lr_ratio` should be specified.
+ Default: None.
+ """
+
+ def __init__(self,
+ start_percent=0.75,
+ min_lr=None,
+ min_lr_ratio=None,
+ **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ if start_percent < 0 or start_percent > 1 or not isinstance(
+ start_percent, float):
+ raise ValueError(
+ 'expected float between 0 and 1 start_percent, but '
+ f'got {start_percent}')
+ self.start_percent = start_percent
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ start = round(runner.max_epochs * self.start_percent)
+ progress = runner.epoch - start
+ max_progress = runner.max_epochs - start
+ else:
+ start = round(runner.max_iters * self.start_percent)
+ progress = runner.iter - start
+ max_progress = runner.max_iters - start
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ if progress < 0:
+ return base_lr
+ else:
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class CosineRestartLrUpdaterHook(LrUpdaterHook):
+ """Cosine annealing with restarts learning rate scheme.
+
+ Args:
+ periods (list[int]): Periods for each cosine anneling cycle.
+ restart_weights (list[float], optional): Restart weights at each
+ restart iteration. Default: [1].
+ min_lr (float, optional): The minimum lr. Default: None.
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+ Either `min_lr` or `min_lr_ratio` should be specified.
+ Default: None.
+ """
+
+ def __init__(self,
+ periods,
+ restart_weights=[1],
+ min_lr=None,
+ min_lr_ratio=None,
+ **kwargs):
+ assert (min_lr is None) ^ (min_lr_ratio is None)
+ self.periods = periods
+ self.min_lr = min_lr
+ self.min_lr_ratio = min_lr_ratio
+ self.restart_weights = restart_weights
+ assert (len(self.periods) == len(self.restart_weights)
+ ), 'periods and restart_weights should have the same length.'
+ super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+
+ self.cumulative_periods = [
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+ ]
+
+ def get_lr(self, runner, base_lr):
+ if self.by_epoch:
+ progress = runner.epoch
+ else:
+ progress = runner.iter
+
+ if self.min_lr_ratio is not None:
+ target_lr = base_lr * self.min_lr_ratio
+ else:
+ target_lr = self.min_lr
+
+ idx = get_position_from_periods(progress, self.cumulative_periods)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+ current_periods = self.periods[idx]
+
+ alpha = min((progress - nearest_restart) / current_periods, 1)
+ return annealing_cos(base_lr, target_lr, alpha, current_weight)
+
+
+def get_position_from_periods(iteration, cumulative_periods):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_periods = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 3.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_periods (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_periods):
+ if iteration < period:
+ return i
+ raise ValueError(f'Current iteration {iteration} exceeds '
+ f'cumulative_periods {cumulative_periods}')
+
+
+@HOOKS.register_module()
+class CyclicLrUpdaterHook(LrUpdaterHook):
+    """Cyclic LR Scheduler.
+
+    Implement the cyclical learning rate policy (CLR) described in
+    https://arxiv.org/pdf/1506.01186.pdf
+
+    Different from the original paper, we use cosine annealing rather than
+    triangular policy inside a cycle. This improves the performance in the
+    3D detection area.
+
+    Args:
+        by_epoch (bool): Whether to update LR by epoch. Only False is
+            supported.
+        target_ratio (float | int | tuple | list): Relative ratio of the
+            highest LR and the lowest LR to the initial LR. A single number
+            ``r`` (or a one-element sequence) is expanded to
+            ``(r, r / 1e5)``.
+        cyclic_times (int): Number of cycles during training
+        step_ratio_up (float): The ratio of the increasing process of LR in
+            the total cycle. Must be in [0, 1).
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing. Default: 'cos'.
+
+    Raises:
+        ValueError: If ``target_ratio`` has an unsupported type or
+            ``anneal_strategy`` is unknown.
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(10, 1e-4),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 anneal_strategy='cos',
+                 **kwargs):
+        # Plain ints and lists are accepted as well: the assertion message
+        # below already promises "list or tuple", and configs loaded from
+        # JSON/YAML deliver lists rather than tuples.
+        if isinstance(target_ratio, (int, float)):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, (list, tuple)):
+            if len(target_ratio) == 1:
+                target_ratio = (target_ratio[0], target_ratio[0] / 1e5)
+            else:
+                target_ratio = tuple(target_ratio)
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.lr_phases = []  # filled in before_run
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        """Build the up/down phases of one cycle from ``runner.max_iters``."""
+        super(CyclicLrUpdaterHook, self).before_run(runner)
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        # Rebuild instead of appending so that a second call to before_run
+        # does not stack duplicate phases.
+        # Each entry: [start_iter, end_iter, cycle_len, start_ratio,
+        # end_ratio].
+        self.lr_phases = [
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]],
+            [
+                iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+                self.target_ratio[0], self.target_ratio[1]
+            ],
+        ]
+
+    def get_lr(self, runner, base_lr):
+        """Return the annealed LR for the phase containing ``runner.iter``."""
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.lr_phases:
+            # Fold the global iteration into a position inside one cycle.
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return self.anneal_func(base_lr * start_ratio,
+                                        base_lr * end_ratio,
+                                        progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+    """One Cycle LR Scheduler.
+
+    The 1cycle learning rate policy changes the learning rate after every
+    batch. The one cycle learning rate policy is described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    Args:
+        max_lr (float or list or dict): Upper learning rate boundaries in
+            the cycle for each parameter group (a dict is keyed by
+            optimizer name).
+        total_steps (int, optional): The total number of steps in the cycle.
+            Note that if a value is not provided here, it will be the
+            max_iter of runner. Default: None.
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the
+            first two phases will be symmetrical about the step indicated
+            by pct_start).
+            Default: False
+
+    Raises:
+        ValueError: On an invalid ``max_lr``, ``total_steps``, ``pct_start``
+            or ``anneal_strategy``.
+    """
+
+    def __init__(self,
+                 max_lr,
+                 total_steps=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch = False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('the type of max_lr must be one of number, '
+                             f'list or dict, but got {type(max_lr)}')
+        self._max_lr = max_lr
+        # total_steps is stored only when given; before_run falls back to
+        # runner.max_iters when the attribute is absent.
+        if total_steps is not None:
+            if not isinstance(total_steps, int):
+                raise ValueError('the type of total_steps must be int, but '
+                                 f'got {type(total_steps)}')
+            self.total_steps = total_steps
+        # validate pct_start; the type is checked first so that a
+        # non-numeric value gets this ValueError instead of a TypeError
+        # from the comparisons.
+        if (not isinstance(pct_start, float) or pct_start < 0
+                or pct_start > 1):
+            raise ValueError('expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.div_factor = div_factor
+        self.final_div_factor = final_div_factor
+        self.three_phase = three_phase
+        self.lr_phases = []  # filled in before_run
+        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+    def _initial_lrs(self, name, optim):
+        """Derive per-group start LRs (max_lr / div_factor) for ``optim``."""
+        _max_lr = format_param(name, optim, self._max_lr)
+        initial = [lr / self.div_factor for lr in _max_lr]
+        # Keep an 'initial_lr' that is already stored on the group.
+        for group, lr in zip(optim.param_groups, initial):
+            group.setdefault('initial_lr', lr)
+        return initial
+
+    def before_run(self, runner):
+        """Set ``base_lr`` and build the phase table for the cycle."""
+        if hasattr(self, 'total_steps'):
+            total_steps = self.total_steps
+        else:
+            total_steps = runner.max_iters
+        if total_steps < runner.max_iters:
+            raise ValueError(
+                'The total steps must be greater than or equal to max '
+                f'iterations {runner.max_iters} of runner, but total steps '
+                f'is {total_steps}.')
+
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                self.base_lr[k] = self._initial_lrs(k, optim)
+        else:
+            k = type(runner.optimizer).__name__
+            self.base_lr = self._initial_lrs(k, runner.optimizer)
+
+        # Rebuild instead of appending so that a repeated before_run does
+        # not stack duplicate phases.
+        # Each entry: [end_iter, start_ratio, end_ratio] relative to base_lr.
+        warmup_end = float(self.pct_start * total_steps) - 1
+        if self.three_phase:
+            self.lr_phases = [
+                [warmup_end, 1, self.div_factor],
+                [
+                    float(2 * self.pct_start * total_steps) - 2,
+                    self.div_factor, 1
+                ],
+                [total_steps - 1, 1, 1 / self.final_div_factor],
+            ]
+        else:
+            self.lr_phases = [
+                [warmup_end, 1, self.div_factor],
+                [total_steps - 1, self.div_factor, 1 / self.final_div_factor],
+            ]
+
+    def get_lr(self, runner, base_lr):
+        """Return the annealed LR for the phase containing ``runner.iter``."""
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+            if curr_iter <= end_iter:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+                                      pct)
+                break
+            start_iter = end_iter
+        return lr
+
+
+def annealing_cos(start, end, factor, weight=1):
+    """Cosine interpolation between two learning rates.
+
+    The value moves from ``weight * start + (1 - weight) * end`` to ``end``
+    as ``factor`` goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the cosine annealing.
+        end (float): The ending learning rate of the cosine annealing.
+        factor (float): Progress through the annealing; it is multiplied by
+            ``pi`` inside the cosine. Range from 0.0 to 1.0.
+        weight (float, optional): The combination factor of `start` and
+            `end` when calculating the actual starting learning rate.
+            Default to 1.
+    """
+    span = start - end
+    # Runs from 2 at factor == 0 down to 0 at factor == 1.
+    cosine_term = 1 + cos(factor * pi)
+    return end + 0.5 * weight * span * cosine_term
+
+
+def annealing_linear(start, end, factor):
+    """Linear interpolation between two learning rates.
+
+    The value moves from ``start`` to ``end`` as ``factor`` goes from 0.0 to
+    1.0.
+
+    Args:
+        start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): Progress through the annealing. Range from 0.0 to
+            1.0.
+    """
+    delta = end - start
+    return start + delta * factor
+
+
+def format_param(name, optim, param):
+    """Expand a hyper-parameter spec into one value per param group.
+
+    Args:
+        name (str): Key used for dict lookup and in error messages.
+        optim: Object with a ``param_groups`` sequence.
+        param: A number (repeated for every group), a list/tuple (returned
+            as is after a length check), or a mapping looked up by ``name``.
+
+    Returns:
+        The per-group values for ``optim``.
+
+    Raises:
+        ValueError: If a list/tuple does not have one value per group.
+        KeyError: If ``name`` is missing from a mapping.
+    """
+    # multi param groups
+    if isinstance(param, (list, tuple)):
+        expected = len(optim.param_groups)
+        if len(param) != expected:
+            raise ValueError(f'expected {expected} '
+                             f'values for {name}, got {len(param)}')
+        return param
+    # one scalar shared by every group
+    if isinstance(param, numbers.Number):
+        return [param] * len(optim.param_groups)
+    # multi optimizers
+    if name not in param:
+        raise KeyError(f'{name} is not found in {param.keys()}')
+    return param[name]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EmptyCacheHook(Hook):
+    """Hook that calls ``torch.cuda.empty_cache()`` at selected points.
+
+    Args:
+        before_epoch (bool): Empty the cache in ``before_epoch``.
+            Default: False.
+        after_epoch (bool): Empty the cache in ``after_epoch``.
+            Default: True.
+        after_iter (bool): Empty the cache in ``after_iter``.
+            Default: False.
+    """
+
+    def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+        self._after_iter = after_iter
+        self._after_epoch = after_epoch
+        self._before_epoch = before_epoch
+
+    def before_epoch(self, runner):
+        if not self._before_epoch:
+            return
+        torch.cuda.empty_cache()
+
+    def after_epoch(self, runner):
+        if not self._after_epoch:
+            return
+        torch.cuda.empty_cache()
+
+    def after_iter(self, runner):
+        if not self._after_iter:
+            return
+        torch.cuda.empty_cache()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdc70246280c2318f51034bb6b66eade7b478b79
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import annotator.mmpkg.mmcv as mmcv
+from .hook import HOOKS, Hook
+from .lr_updater import annealing_cos, annealing_linear, format_param
+
+
+class MomentumUpdaterHook(Hook):
+    """Base class for hooks that schedule optimizer momentum.
+
+    Subclasses implement :meth:`get_momentum`. The scheduled value is
+    written to ``param_group['momentum']`` when that key exists, otherwise
+    to the first element of ``param_group['betas']``.
+
+    Args:
+        by_epoch (bool): Refresh the scheduled momentum once per epoch
+            (True) or on every iteration (False). Default: True.
+        warmup (str, optional): One of 'constant', 'linear' or 'exp'.
+            None disables warmup. Default: None.
+        warmup_iters (int): Number of warmup iterations; must be positive
+            when ``warmup`` is set. Default: 0.
+        warmup_ratio (float): Ratio in (0, 1] that the scheduled momentum is
+            divided by at the start of warmup. Default: 0.9.
+
+    Raises:
+        ValueError: If ``warmup`` is not a supported strategy.
+    """
+
+    def __init__(self,
+                 by_epoch=True,
+                 warmup=None,
+                 warmup_iters=0,
+                 warmup_ratio=0.9):
+        # validate the "warmup" argument
+        if warmup is not None:
+            if warmup not in ['constant', 'linear', 'exp']:
+                raise ValueError(
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant", "linear" and "exp"')
+            assert warmup_iters > 0, \
+                '"warmup_iters" must be a positive integer'
+            assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_ratio" must be in range (0,1]'
+
+        self.by_epoch = by_epoch
+        self.warmup = warmup
+        self.warmup_iters = warmup_iters
+        self.warmup_ratio = warmup_ratio
+
+        self.base_momentum = []  # initial momentum for all param groups
+        # expected momentum if no warming up is performed
+        self.regular_momentum = []
+        # Older code stored the same value as ``regular_mom``; keep that
+        # name as an alias for anything that still reads it.
+        self.regular_mom = self.regular_momentum
+
+    @staticmethod
+    def _write_momentum(param_groups, momentums):
+        """Write one momentum value into each param group, pairwise."""
+        for param_group, mom in zip(param_groups, momentums):
+            if 'momentum' in param_group.keys():
+                param_group['momentum'] = mom
+            elif 'betas' in param_group.keys():
+                param_group['betas'] = (mom, param_group['betas'][1])
+
+    def _set_momentum(self, runner, momentum_groups):
+        """Apply ``momentum_groups`` to the runner's optimizer(s)."""
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                self._write_momentum(optim.param_groups, momentum_groups[k])
+        else:
+            self._write_momentum(runner.optimizer.param_groups,
+                                 momentum_groups)
+
+    def get_momentum(self, runner, base_momentum):
+        """Return the scheduled momentum; implemented by subclasses."""
+        raise NotImplementedError
+
+    def get_regular_momentum(self, runner):
+        """Evaluate :meth:`get_momentum` for every param group.
+
+        Returns:
+            A list of momentums, or a dict of such lists keyed by optimizer
+            name when ``runner.optimizer`` is a dict.
+        """
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k in runner.optimizer.keys():
+                _momentum_group = [
+                    self.get_momentum(runner, _base_momentum)
+                    for _base_momentum in self.base_momentum[k]
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            return [
+                self.get_momentum(runner, _base_momentum)
+                for _base_momentum in self.base_momentum
+            ]
+
+    def get_warmup_momentum(self, cur_iters):
+        """Scale the regular momentum for warmup iteration ``cur_iters``.
+
+        Returns:
+            Same structure as ``self.regular_momentum`` (list, or dict of
+            lists) with every value divided by the warmup factor.
+        """
+
+        def _get_warmup_momentum(cur_iters, regular_momentum):
+            # Every branch works on the list passed in. Reading instance
+            # attributes here instead breaks the multi-optimizer (dict)
+            # case and made 'constant' warmup operate on an empty list.
+            if self.warmup == 'constant':
+                warmup_momentum = [
+                    _momentum / self.warmup_ratio
+                    for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_momentum = [
+                    _momentum / (1 - k) for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_momentum = [
+                    _momentum / k for _momentum in regular_momentum
+                ]
+            return warmup_momentum
+
+        if isinstance(self.regular_momentum, dict):
+            momentum_groups = {}
+            for key, regular_momentum in self.regular_momentum.items():
+                momentum_groups[key] = _get_warmup_momentum(
+                    cur_iters, regular_momentum)
+            return momentum_groups
+        else:
+            return _get_warmup_momentum(cur_iters, self.regular_momentum)
+
+    @staticmethod
+    def _collect_initial_momentum(optim):
+        """Record and return ``initial_momentum`` for each param group."""
+        for group in optim.param_groups:
+            # setdefault keeps a value that is already stored on the group.
+            if 'momentum' in group.keys():
+                group.setdefault('initial_momentum', group['momentum'])
+            else:
+                group.setdefault('initial_momentum', group['betas'][0])
+        return [group['initial_momentum'] for group in optim.param_groups]
+
+    def before_run(self, runner):
+        # NOTE: when resuming from a checkpoint,
+        # if 'initial_momentum' is not saved,
+        # it will be set according to the optimizer params
+        if isinstance(runner.optimizer, dict):
+            self.base_momentum = {}
+            for k, optim in runner.optimizer.items():
+                self.base_momentum.update(
+                    {k: self._collect_initial_momentum(optim)})
+        else:
+            self.base_momentum = self._collect_initial_momentum(
+                runner.optimizer)
+
+    def _refresh_regular_momentum(self, runner):
+        """Recompute the scheduled momentum and update the legacy alias."""
+        self.regular_momentum = self.get_regular_momentum(runner)
+        self.regular_mom = self.regular_momentum
+
+    def before_train_epoch(self, runner):
+        if not self.by_epoch:
+            return
+        self._refresh_regular_momentum(runner)
+        self._set_momentum(runner, self.regular_momentum)
+
+    def before_train_iter(self, runner):
+        cur_iter = runner.iter
+        if not self.by_epoch:
+            self._refresh_regular_momentum(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+        else:
+            # Epoch-based schedules only act per iteration during warmup,
+            # plus once more on the iteration where warmup ends.
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+
+
+@HOOKS.register_module()
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+    """Step momentum scheduler with min value clipping.
+
+    Args:
+        step (int | list[int]): Step to decay the momentum. If an int value
+            is given, regard it as the decay interval. If a list is given,
+            decay momentum at these steps.
+        gamma (float, optional): Decay momentum ratio. Default: 0.5.
+        min_momentum (float, optional): Minimum momentum value to keep. If
+            momentum after decay is lower than this value, it will be
+            clipped accordingly. If None is given, no clipping is performed.
+            Default: None.
+
+    Raises:
+        TypeError: If ``step`` is neither a list nor an int.
+    """
+
+    def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+        if isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all([milestone > 0 for milestone in step])
+        elif isinstance(step, int):
+            assert step > 0
+        else:
+            raise TypeError('"step" must be a list or integer')
+        self.step = step
+        self.gamma = gamma
+        self.min_momentum = min_momentum
+        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        """Return ``base_momentum * gamma**k`` for the current progress."""
+        progress = runner.epoch if self.by_epoch else runner.iter
+
+        # k is the number of decays applied so far.
+        if isinstance(self.step, int):
+            num_decays = progress // self.step
+        else:
+            # Index of the first milestone not yet reached, or the number
+            # of milestones when none is ahead.
+            num_decays = next(
+                (idx for idx, milestone in enumerate(self.step)
+                 if progress < milestone), len(self.step))
+
+        momentum = base_momentum * (self.gamma**num_decays)
+        if self.min_momentum is None:
+            return momentum
+        # clip to a minimum value
+        return max(momentum, self.min_momentum)
+
+
+@HOOKS.register_module()
+class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
+    """Cosine-anneal momentum from its base value towards a target.
+
+    Exactly one of ``min_momentum`` and ``min_momentum_ratio`` must be
+    given.
+
+    Args:
+        min_momentum (float, optional): Absolute target momentum.
+        min_momentum_ratio (float, optional): Target expressed as a
+            multiple of the base momentum.
+    """
+
+    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+        assert (min_momentum is None) ^ (min_momentum_ratio is None)
+        self.min_momentum = min_momentum
+        self.min_momentum_ratio = min_momentum_ratio
+        super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        """Return the annealed momentum for the runner's current progress."""
+        if self.by_epoch:
+            done, total = runner.epoch, runner.max_epochs
+        else:
+            done, total = runner.iter, runner.max_iters
+
+        if self.min_momentum_ratio is None:
+            target_momentum = self.min_momentum
+        else:
+            target_momentum = base_momentum * self.min_momentum_ratio
+
+        return annealing_cos(base_momentum, target_momentum, done / total)
+
+
+@HOOKS.register_module()
+class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
+    """Cyclic momentum Scheduler.
+
+    Implement the cyclical momentum scheduler policy described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    This momentum scheduler usually used together with the CyclicLRUpdater
+    to improve the performance in the 3D detection area.
+
+    Args:
+        by_epoch (bool): Whether to update momentum by epoch. Only False is
+            supported.
+        target_ratio (float | int | tuple | list): Relative ratio of the
+            lowest momentum and the highest momentum to the initial
+            momentum. A single number ``r`` (or a one-element sequence) is
+            expanded to ``(r, r / 1e5)``.
+        cyclic_times (int): Number of cycles during training
+        step_ratio_up (float): The ratio of the increasing process of
+            momentum in the total cycle. Must be in [0, 1).
+
+    Raises:
+        ValueError: If ``target_ratio`` has an unsupported type.
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(0.85 / 0.95, 1),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 **kwargs):
+        # Plain ints and lists are accepted as well: the assertion message
+        # below already promises "list or tuple", and configs loaded from
+        # JSON/YAML deliver lists rather than tuples.
+        if isinstance(target_ratio, (int, float)):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, (list, tuple)):
+            if len(target_ratio) == 1:
+                target_ratio = (target_ratio[0], target_ratio[0] / 1e5)
+            else:
+                target_ratio = tuple(target_ratio)
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.momentum_phases = []  # filled in before_run
+        # currently only support by_epoch=False
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        """Build the two phases of one cycle from ``runner.max_iters``."""
+        super(CyclicMomentumUpdaterHook, self).before_run(runner)
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        # Rebuild instead of appending so that a second call to before_run
+        # does not stack duplicate phases.
+        # Each entry: [start_iter, end_iter, cycle_len, start_ratio,
+        # end_ratio].
+        self.momentum_phases = [
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]],
+            [
+                iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+                self.target_ratio[0], self.target_ratio[1]
+            ],
+        ]
+
+    def get_momentum(self, runner, base_momentum):
+        """Return the cosine-annealed momentum for ``runner.iter``."""
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.momentum_phases:
+            # Fold the global iteration into a position inside one cycle.
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return annealing_cos(base_momentum * start_ratio,
+                                     base_momentum * end_ratio,
+                                     progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+    """OneCycle momentum Scheduler.
+
+    This momentum scheduler usually used together with the OneCycleLrUpdater
+    to improve the performance.
+
+    Args:
+        base_momentum (float or list or dict): Lower momentum boundaries in
+            the cycle for each parameter group. Note that momentum is cycled
+            inversely to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list or dict): Upper momentum boundaries in
+            the cycle for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            Note that momentum is cycled inversely
+            to learning rate; at the start of a cycle, momentum is
+            'max_momentum' and learning rate is 'base_lr'
+            Default: 0.95
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        three_phase (bool): If True, momentum goes from max to base, back
+            to max symmetrically about the step indicated by pct_start, and
+            is then held at max_momentum until the end. If False, it goes
+            from max to base during the pct_start part and from base to max
+            for the remainder.
+            Default: False
+
+    Raises:
+        ValueError: On an invalid ``base_momentum``, ``max_momentum``,
+            ``pct_start`` or ``anneal_strategy``, or (in ``before_run``) when
+            an optimizer has neither 'momentum' nor 'betas' in its defaults.
+    """
+
+    def __init__(self,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch=False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be the type among of float, '
+                             'list or dict.')
+        self._base_momentum = base_momentum
+        if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be the type among of float, '
+                             'list or dict.')
+        self._max_momentum = max_momentum
+        # validate pct_start; the type is checked first so that a
+        # non-numeric value gets this ValueError instead of a TypeError
+        # from the comparisons.
+        if (not isinstance(pct_start, float) or pct_start < 0
+                or pct_start > 1):
+            raise ValueError('Expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must by one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.three_phase = three_phase
+        self.momentum_phases = []  # filled in before_run
+        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def _init_optimizer_momentum(self, name, optim):
+        """Seed one optimizer's param groups with the cycle boundaries."""
+        if ('momentum' not in optim.defaults
+                and 'betas' not in optim.defaults):
+            raise ValueError('optimizer must support momentum with '
+                             'option enabled')
+        self.use_beta1 = 'betas' in optim.defaults
+        _base_momentum = format_param(name, optim, self._base_momentum)
+        _max_momentum = format_param(name, optim, self._max_momentum)
+        for group, b_momentum, m_momentum in zip(optim.param_groups,
+                                                 _base_momentum,
+                                                 _max_momentum):
+            # The cycle starts at max_momentum.
+            if self.use_beta1:
+                _, beta2 = group['betas']
+                group['betas'] = (m_momentum, beta2)
+            else:
+                group['momentum'] = m_momentum
+            # get_momentum reads these two keys back from the group.
+            group['base_momentum'] = b_momentum
+            group['max_momentum'] = m_momentum
+
+    def before_run(self, runner):
+        """Initialise param groups and build the phase table."""
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                self._init_optimizer_momentum(k, optim)
+        else:
+            optim = runner.optimizer
+            self._init_optimizer_momentum(type(optim).__name__, optim)
+
+        # Rebuild instead of appending so that a repeated before_run does
+        # not stack duplicate phases. 'start_momentum' / 'end_momentum' hold
+        # the param-group keys to read the boundary values from.
+        first_end = float(self.pct_start * runner.max_iters) - 1
+        last_end = runner.max_iters - 1
+        if self.three_phase:
+            self.momentum_phases = [
+                {
+                    'end_iter': first_end,
+                    'start_momentum': 'max_momentum',
+                    'end_momentum': 'base_momentum'
+                },
+                {
+                    'end_iter':
+                    float(2 * self.pct_start * runner.max_iters) - 2,
+                    'start_momentum': 'base_momentum',
+                    'end_momentum': 'max_momentum'
+                },
+                {
+                    'end_iter': last_end,
+                    'start_momentum': 'max_momentum',
+                    'end_momentum': 'max_momentum'
+                },
+            ]
+        else:
+            self.momentum_phases = [
+                {
+                    'end_iter': first_end,
+                    'start_momentum': 'max_momentum',
+                    'end_momentum': 'base_momentum'
+                },
+                {
+                    'end_iter': last_end,
+                    'start_momentum': 'base_momentum',
+                    'end_momentum': 'max_momentum'
+                },
+            ]
+
+    def _set_momentum(self, runner, momentum_groups):
+        """Apply ``momentum_groups`` to the runner's optimizer(s)."""
+        if isinstance(runner.optimizer, dict):
+            pairs = [(optim.param_groups, momentum_groups[k])
+                     for k, optim in runner.optimizer.items()]
+        else:
+            pairs = [(runner.optimizer.param_groups, momentum_groups)]
+        for param_groups, momentums in pairs:
+            for param_group, mom in zip(param_groups, momentums):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, param_group):
+        """Return the annealed momentum of ``param_group`` at ``runner.iter``.
+
+        Args:
+            runner: Object providing the ``iter`` counter.
+            param_group (dict): Param group holding the 'base_momentum' and
+                'max_momentum' keys written by ``before_run``.
+        """
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, phase in enumerate(self.momentum_phases):
+            end_iter = phase['end_iter']
+            # The last phase also covers iterations past its end_iter.
+            if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
+                break
+            start_iter = end_iter
+        return momentum
+
+    def get_regular_momentum(self, runner):
+        """Evaluate :meth:`get_momentum` for every param group.
+
+        Returns:
+            A list of momentums, or a dict of such lists keyed by optimizer
+            name when ``runner.optimizer`` is a dict.
+        """
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k, optim in runner.optimizer.items():
+                _momentum_group = [
+                    self.get_momentum(runner, param_group)
+                    for param_group in optim.param_groups
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            momentum_groups = []
+            for param_group in runner.optimizer.param_groups:
+                momentum_groups.append(self.get_momentum(runner, param_group))
+            return momentum_groups
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..580a183639a5d95c04ecae9c619afb795a169e9e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+from itertools import chain
+
+from torch.nn.utils import clip_grad
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import LossScaler, wrap_fp16_model
+from .hook import HOOKS, Hook
+
+try:
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+ from torch.cuda.amp import GradScaler
+except ImportError:
+ pass
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
+    """Backward the loss and step the optimizer after each train iter."""
+
+    def __init__(self, grad_clip=None):
+        self.grad_clip = grad_clip  # kwargs for clip_grad_norm_, or None
+
+    def clip_grads(self, params):
+        kept = [p for p in params if p.requires_grad and p.grad is not None]
+        if kept:
+            return clip_grad.clip_grad_norm_(kept, **self.grad_clip)
+
+    def after_train_iter(self, runner):
+        runner.optimizer.zero_grad()
+        runner.outputs['loss'].backward()
+        if self.grad_clip is not None:
+            total_norm = self.clip_grads(runner.model.parameters())
+            if total_norm is not None:
+                # record the gradient norm in the log buffer
+                runner.log_buffer.update({'grad_norm': float(total_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()
+
+
+@HOOKS.register_module()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+    """Optimizer Hook implements multi-iters gradient cumulating.
+
+    Args:
+        cumulative_iters (int, optional): Num of gradient cumulative iters.
+            The optimizer will step every `cumulative_iters` iters.
+            Defaults to 1.
+
+    Examples:
+        >>> # Use cumulative_iters to simulate a large batch size
+        >>> # It is helpful when the hardware cannot handle a large batch size.
+        >>> loader = DataLoader(data, batch_size=64)
+        >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+        >>> # almost equals to
+        >>> loader = DataLoader(data, batch_size=256)
+        >>> optim_hook = OptimizerHook()
+    """
+
+    def __init__(self, cumulative_iters=1, **kwargs):
+        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+
+        assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+            f'cumulative_iters only accepts positive int, but got ' \
+            f'{type(cumulative_iters)} instead.'
+
+        self.cumulative_iters = cumulative_iters
+        self.divisible_iters = 0
+        self.remainder_iters = 0
+        self.initialized = False
+
+    def has_batch_norm(self, module):
+        """Return True if ``module`` or any descendant is a ``_BatchNorm``."""
+        if isinstance(module, _BatchNorm):
+            return True
+        return any(self.has_batch_norm(m) for m in module.children())
+
+    def _init(self, runner):
+        """Lazily compute the accumulation boundaries on the first iter."""
+        if runner.iter % self.cumulative_iters != 0:
+            runner.logger.warning(
+                'Resume iter number is not divisible by cumulative_iters in '
+                'GradientCumulativeOptimizerHook, which means the gradient of '
+                'some iters is lost and the result may be influenced slightly.'
+            )
+
+        if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+            runner.logger.warning(
+                'GradientCumulativeOptimizerHook may slightly decrease '
+                'performance if the model has BatchNorm layers.')
+
+        # Boundaries are absolute iteration numbers (counted from 0, not from
+        # the resume point) because ``after_train_iter`` compares them with
+        # ``runner.iter``; a relative count breaks loss scaling after resume.
+        self.divisible_iters = (
+            runner.max_iters // self.cumulative_iters * self.cumulative_iters)
+        self.remainder_iters = runner.max_iters - self.divisible_iters
+
+        self.initialized = True
+
+    def after_train_iter(self, runner):
+        """Accumulate gradients and step once per accumulation window."""
+        if not self.initialized:
+            self._init(runner)
+
+        # Divide by the window length so the accumulated gradient is an
+        # average; the trailing window may be shorter than the others.
+        if runner.iter < self.divisible_iters:
+            loss_factor = self.cumulative_iters
+        else:
+            loss_factor = self.remainder_iters
+        (runner.outputs['loss'] / loss_factor).backward()
+
+        if (self.every_n_iters(runner, self.cumulative_iters)
+                or self.is_last_iter(runner)):
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            runner.optimizer.step()
+            runner.optimizer.zero_grad()
+
+
+if (TORCH_VERSION != 'parrots'
+ and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (using PyTorch's implementation).
+
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+ to take care of the optimization procedure.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+                It can also be a dict containing arguments of GradScaler.
+ Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+ implementation of GradScaler. If you use a dict version of
+ loss_scale to create GradScaler, please refer to:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+ for the parameters.
+
+ Examples:
+ >>> loss_scale = dict(
+ ... init_scale=65536.0,
+ ... growth_factor=2.0,
+ ... backoff_factor=0.5,
+ ... growth_interval=2000
+ ... )
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+ """
+
+        def __init__(self,
+                     grad_clip=None,
+                     coalesce=True,
+                     bucket_size_mb=-1,
+                     loss_scale=512.,
+                     distributed=True):
+            self.grad_clip, self.coalesce = grad_clip, coalesce
+            self.bucket_size_mb, self.distributed = bucket_size_mb, distributed
+            # forwarded to ``loss_scaler.update``; None unless a float is given
+            self._scale_update_param = None
+            if loss_scale == 'dynamic':
+                scaler_kwargs = {}
+            elif isinstance(loss_scale, float):
+                self._scale_update_param = loss_scale
+                scaler_kwargs = {'init_scale': loss_scale}
+            elif isinstance(loss_scale, dict):
+                scaler_kwargs = loss_scale
+            else:
+                raise ValueError('loss_scale must be of type float, dict, or '
+                                 f'"dynamic", got {loss_scale}')
+            self.loss_scaler = GradScaler(**scaler_kwargs)
+
+        def before_run(self, runner):
+            """Wrap the model for fp16 and restore a saved scaler state."""
+            # wrap model mode to fp16
+            wrap_fp16_model(runner.model)
+            # a resumed run carries the scaler state in runner.meta['fp16']
+            fp16_meta = runner.meta['fp16'] if 'fp16' in runner.meta else {}
+            if 'loss_scaler' in fp16_meta:
+                self.loss_scaler.load_state_dict(fp16_meta['loss_scaler'])
+
+        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+            """Copy gradients from fp16 model to fp32 weight copy."""
+            for master, half in zip(fp32_weights, fp16_net.parameters()):
+                if half.grad is None:
+                    continue
+                if master.grad is None:
+                    # allocate the master gradient buffer on first use
+                    master.grad = master.data.new(master.size())
+                master.grad.copy_(half.grad)
+
+        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+            """Overwrite each model param with its fp32 weight copy."""
+            # the two sequences are paired purely by iteration order
+            for half, master in zip(fp16_net.parameters(), fp32_weights):
+                half.data.copy_(master.data)
+
+        def after_train_iter(self, runner):
+            """Backward optimization steps for Mixed Precision Training. For
+            dynamic loss scaling, please refer to
+            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
+
+            1. Scale the loss by a scale factor.
+            2. Backward the loss to obtain the gradients.
+            3. Unscale the optimizer's gradient tensors.
+            4. Call optimizer.step() and update scale factor.
+            5. Save loss_scaler state_dict for resume purpose.
+            """
+            scaler, optimizer = self.loss_scaler, runner.optimizer
+            # drop gradients left over from the previous iteration
+            runner.model.zero_grad()
+            optimizer.zero_grad()
+
+            scaler.scale(runner.outputs['loss']).backward()
+            # gradients are unscaled before they are clipped
+            scaler.unscale_(optimizer)
+            if self.grad_clip is not None:
+                total_norm = self.clip_grads(runner.model.parameters())
+                if total_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(total_norm)},
+                                             runner.outputs['num_samples'])
+            # step the optimizer through the scaler, then update the scale
+            scaler.step(optimizer)
+            scaler.update(self._scale_update_param)
+            # keep the scaler state in runner.meta so it can be resumed
+            state = scaler.state_dict()
+            runner.meta.setdefault('fp16', {})['loss_scaler'] = state
+
+    @HOOKS.register_module()
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using PyTorch's implementation) implements
+        multi-iters gradient cumulating.
+
+        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+        to take care of the optimization procedure.
+        """
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            """Accumulate scaled gradients; step at each window boundary."""
+            if not self.initialized:
+                self._init(runner)
+
+            full_window = runner.iter < self.divisible_iters
+            loss_factor = (
+                self.cumulative_iters if full_window else self.remainder_iters)
+            scaler = self.loss_scaler
+            scaler.scale(runner.outputs['loss'] / loss_factor).backward()
+
+            window_closed = (
+                self.every_n_iters(runner, self.cumulative_iters)
+                or self.is_last_iter(runner))
+            if not window_closed:
+                return
+
+            # gradients are unscaled before they are clipped
+            scaler.unscale_(runner.optimizer)
+
+            if self.grad_clip is not None:
+                total_norm = self.clip_grads(runner.model.parameters())
+                if total_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(total_norm)},
+                                             runner.outputs['num_samples'])
+
+            # step the optimizer through the scaler, then update the scale
+            scaler.step(runner.optimizer)
+            scaler.update(self._scale_update_param)
+
+            # keep the scaler state in runner.meta so it can be resumed
+            state = scaler.state_dict()
+            runner.meta.setdefault('fp16', {})['loss_scaler'] = state
+
+            # clear grads
+            runner.model.zero_grad()
+            runner.optimizer.zero_grad()
+
+else:
+
+ @HOOKS.register_module()
+ class Fp16OptimizerHook(OptimizerHook):
+ """FP16 optimizer hook (mmcv's implementation).
+
+ The steps of fp16 optimizer is as follows.
+        1. Scale the loss value.
+        2. BP in the fp16 model.
+        3. Copy gradients from fp16 model to fp32 weights.
+        4. Update fp32 weights.
+        5. Copy updated parameters from fp32 weights to fp16 model.
+
+ Refer to https://arxiv.org/abs/1710.03740 for more details.
+
+ Args:
+ loss_scale (float | str | dict): Scale factor configuration.
+ If loss_scale is a float, static loss scaling will be used with
+ the specified scale. If loss_scale is a string, it must be
+ 'dynamic', then dynamic loss scaling will be used.
+ It can also be a dict containing arguments of LossScaler.
+ Defaults to 512.
+ """
+
+        def __init__(self,
+                     grad_clip=None,
+                     coalesce=True,
+                     bucket_size_mb=-1,
+                     loss_scale=512.,
+                     distributed=True):
+            self.grad_clip, self.coalesce = grad_clip, coalesce
+            self.bucket_size_mb, self.distributed = bucket_size_mb, distributed
+            # 'dynamic' -> dynamic scaling, float -> static scaling with that
+            # value, dict -> keyword arguments passed straight to LossScaler
+            if loss_scale == 'dynamic':
+                scaler_kwargs = dict(mode='dynamic')
+            elif isinstance(loss_scale, float):
+                scaler_kwargs = dict(init_scale=loss_scale, mode='static')
+            elif isinstance(loss_scale, dict):
+                scaler_kwargs = loss_scale
+            else:
+                raise ValueError('loss_scale must be of type float, dict, or '
+                                 f'"dynamic", got {loss_scale}')
+            self.loss_scaler = LossScaler(**scaler_kwargs)
+
+        def before_run(self, runner):
+            """Preparing steps before Mixed Precision Training.
+
+            1. Make a master copy of fp32 weights for optimization.
+            2. Convert the main model from fp32 to fp16.
+            """
+            # deep-copy the param groups so the optimizer owns separate params
+            old_groups = runner.optimizer.param_groups
+            runner.optimizer.param_groups = copy.deepcopy(
+                runner.optimizer.param_groups)
+            state = defaultdict(dict)  # optimizer state re-keyed below
+            p_map = {  # original param -> its copy, paired by position
+                old_p: p
+                for old_p, p in zip(
+                    chain(*(g['params'] for g in old_groups)),
+                    chain(*(g['params']
+                            for g in runner.optimizer.param_groups)))
+            }
+            for k, v in runner.optimizer.state.items():
+                state[p_map[k]] = v
+            runner.optimizer.state = state
+            # convert model to fp16
+            wrap_fp16_model(runner.model)
+            # resume from state dict
+            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+                scaler_state_dict = runner.meta['fp16']['loss_scaler']
+                self.loss_scaler.load_state_dict(scaler_state_dict)
+
+        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+            """Copy gradients from fp16 model to fp32 weight copy."""
+            for master, half in zip(fp32_weights, fp16_net.parameters()):
+                if half.grad is None:
+                    continue
+                if master.grad is None:
+                    # allocate the master gradient buffer on first use
+                    master.grad = master.data.new(master.size())
+                master.grad.copy_(half.grad)
+
+        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+            """Overwrite each model param with its fp32 weight copy."""
+            # the two sequences are paired purely by iteration order
+            for half, master in zip(fp16_net.parameters(), fp32_weights):
+                half.data.copy_(master.data)
+
+        def after_train_iter(self, runner):
+            """Backward optimization steps for Mixed Precision Training. For
+            dynamic loss scaling, please refer to ``LossScaler``.
+
+            1. Scale the loss by a scale factor.
+            2. Backward the loss to obtain the gradients (fp16).
+            3. Copy gradients from the model to the fp32 weight copy.
+            4. Scale the gradients back and update the fp32 weight copy.
+            5. Copy back the params from fp32 weight copy to the fp16 model.
+            6. Save loss_scaler state_dict for resume purpose.
+            """
+            scaler = self.loss_scaler
+            # clear grads of last iteration
+            runner.model.zero_grad()
+            runner.optimizer.zero_grad()
+            # backward on the scaled loss
+            (runner.outputs['loss'] * scaler.loss_scale).backward()
+
+            # gather the optimizer's params and hand them the model grads
+            fp32_weights = [
+                p for group in runner.optimizer.param_groups
+                for p in group['params']
+            ]
+            self.copy_grads_to_fp32(runner.model, fp32_weights)
+            if self.distributed:
+                allreduce_grads(fp32_weights, self.coalesce,
+                                self.bucket_size_mb)
+
+            has_overflow = scaler.has_overflow(fp32_weights)
+            if has_overflow:
+                # an overflowed iteration is skipped; only the scale changes
+                scaler.update_scale(has_overflow)
+                runner.logger.warning('Check overflow, downscale loss scale '
+                                      f'to {scaler.cur_scale}')
+            else:
+                # scale the gradients back
+                for param in fp32_weights:
+                    if param.grad is not None:
+                        param.grad.div_(scaler.loss_scale)
+                if self.grad_clip is not None:
+                    total_norm = self.clip_grads(fp32_weights)
+                    if total_norm is not None:
+                        # Add grad norm to the logger
+                        runner.log_buffer.update(
+                            {'grad_norm': float(total_norm)},
+                            runner.outputs['num_samples'])
+                # update fp32 params, then mirror them into the fp16 model
+                runner.optimizer.step()
+                self.copy_params_to_fp16(runner.model, fp32_weights)
+                scaler.update_scale(has_overflow)
+
+            # save state_dict of loss_scaler
+            state = scaler.state_dict()
+            runner.meta.setdefault('fp16', {})['loss_scaler'] = state
+
+    @HOOKS.register_module()
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using mmcv implementation) implements multi-
+        iters gradient cumulating."""
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            """Accumulate scaled grads; update weights when a window closes."""
+            if not self.initialized:
+                self._init(runner)
+
+            if runner.iter < self.divisible_iters:
+                loss_factor = self.cumulative_iters
+            else:
+                loss_factor = self.remainder_iters
+
+            # average over the window, then apply the loss scale
+            loss = runner.outputs['loss'] / loss_factor
+            scaled_loss = loss * self.loss_scaler.loss_scale
+            scaled_loss.backward()
+
+            if (self.every_n_iters(runner, self.cumulative_iters)
+                    or self.is_last_iter(runner)):
+
+                # copy fp16 grads in the model to fp32 params in the optimizer
+                fp32_weights = []
+                for param_group in runner.optimizer.param_groups:
+                    fp32_weights += param_group['params']
+                self.copy_grads_to_fp32(runner.model, fp32_weights)
+                # allreduce grads
+                if self.distributed:
+                    allreduce_grads(fp32_weights, self.coalesce,
+                                    self.bucket_size_mb)
+
+                has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+                # if has overflow, skip this iteration
+                if not has_overflow:
+                    # scale the gradients back
+                    for param in fp32_weights:
+                        if param.grad is not None:
+                            param.grad.div_(self.loss_scaler.loss_scale)
+                    if self.grad_clip is not None:
+                        grad_norm = self.clip_grads(fp32_weights)
+                        if grad_norm is not None:
+                            # Add grad norm to the logger
+                            runner.log_buffer.update(
+                                {'grad_norm': float(grad_norm)},
+                                runner.outputs['num_samples'])
+                    # update fp32 params
+                    runner.optimizer.step()
+                    # copy fp32 params to the fp16 model
+                    self.copy_params_to_fp16(runner.model, fp32_weights)
+                # Update the scale before logging so the warning reports the
+                # new scale, in the same order as Fp16OptimizerHook.
+                self.loss_scaler.update_scale(has_overflow)
+                if has_overflow:
+                    runner.logger.warning(
+                        'Check overflow, downscale loss scale '
+                        f'to {self.loss_scaler.cur_scale}')
+
+                # save state_dict of loss_scaler
+                runner.meta.setdefault(
+                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+                # clear grads
+                runner.model.zero_grad()
+                runner.optimizer.zero_grad()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from ..dist_utils import master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ProfilerHook(Hook):
+ """Profiler to analyze performance during training.
+
+ PyTorch Profiler is a tool that allows the collection of the performance
+ metrics during the training. More details on Profiler can be found at
+ https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
+
+ Args:
+ by_epoch (bool): Profile performance by epoch or by iteration.
+ Default: True.
+ profile_iters (int): Number of iterations for profiling.
+ If ``by_epoch=True``, profile_iters indicates that they are the
+ first profile_iters epochs at the beginning of the
+ training, otherwise it indicates the first profile_iters
+ iterations. Default: 1.
+ activities (list[str]): List of activity groups (CPU, CUDA) to use in
+ profiling. Default: ['cpu', 'cuda'].
+ schedule (dict, optional): Config of generating the callable schedule.
+ if schedule is None, profiler will not add step markers into the
+ trace and table view. Default: None.
+ on_trace_ready (callable, dict): Either a handler or a dict of generate
+ handler. Default: None.
+ record_shapes (bool): Save information about operator's input shapes.
+ Default: False.
+ profile_memory (bool): Track tensor memory allocation/deallocation.
+ Default: False.
+ with_stack (bool): Record source information (file and line number)
+ for the ops. Default: False.
+ with_flops (bool): Use formula to estimate the FLOPS of specific
+ operators (matrix multiplication and 2D convolution).
+ Default: False.
+ json_trace_path (str, optional): Exports the collected trace in Chrome
+ JSON format. Default: None.
+
+ Example:
+ >>> runner = ... # instantiate a Runner
+ >>> # tensorboard trace
+ >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
+ >>> profiler_config = dict(on_trace_ready=trace_config)
+ >>> runner.register_profiler_hook(profiler_config)
+ >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
+ """
+
+    def __init__(self,
+                 by_epoch: bool = True,
+                 profile_iters: int = 1,
+                 activities: List[str] = ['cpu', 'cuda'],
+                 schedule: Optional[dict] = None,
+                 on_trace_ready: Optional[Union[Callable, dict]] = None,
+                 record_shapes: bool = False,
+                 profile_memory: bool = False,
+                 with_stack: bool = False,
+                 with_flops: bool = False,
+                 json_trace_path: Optional[str] = None) -> None:
+        try:
+            from torch import profiler  # torch version >= 1.8.1
+        except ImportError:
+            raise ImportError('profiler is the new feature of torch1.8.1, '
+                              f'but your version is {torch.__version__}')
+
+        # validate everything first, then store the configuration
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
+        if profile_iters < 1:
+            raise ValueError('profile_iters should be greater than 0, but got '
+                             f'{profile_iters}')
+        if not isinstance(activities, list):
+            raise ValueError(
+                f'activities should be list, but got {type(activities)}')
+
+        # map the case-insensitive activity names onto profiler enums
+        known = {
+            'cpu': profiler.ProfilerActivity.CPU,
+            'cuda': profiler.ProfilerActivity.CUDA,
+        }
+        self.activities = []
+        for name in activities:
+            name = name.lower()
+            if name not in known:
+                raise ValueError(
+                    f'activity should be "cpu" or "cuda", but got {name}')
+            self.activities.append(known[name])
+
+        self.by_epoch = by_epoch
+        self.profile_iters = profile_iters
+        # ``schedule`` holds keyword arguments for ``profiler.schedule``
+        self.schedule = (
+            profiler.schedule(**schedule) if schedule is not None else None)
+        self.on_trace_ready = on_trace_ready
+        self.record_shapes = record_shapes
+        self.profile_memory = profile_memory
+        self.with_stack = with_stack
+        self.with_flops = with_flops
+        self.json_trace_path = json_trace_path
+
+    @master_only
+    def before_run(self, runner):
+        """Validate the setup, resolve the trace handler, start profiling."""
+        if self.by_epoch and runner.max_epochs < self.profile_iters:
+            raise ValueError('self.profile_iters should not be greater than '
+                             f'{runner.max_epochs}')
+        if not self.by_epoch and runner.max_iters < self.profile_iters:
+            raise ValueError('self.profile_iters should not be greater than '
+                             f'{runner.max_iters}')
+
+        if callable(self.on_trace_ready):  # handler
+            _on_trace_ready = self.on_trace_ready
+        elif isinstance(self.on_trace_ready, dict):  # config of handler
+            trace_cfg = self.on_trace_ready.copy()
+            trace_type = trace_cfg.pop('type')  # the rest are handler kwargs
+            if trace_type == 'log_trace':
+
+                def _log_handler(prof):
+                    print(prof.key_averages().table(**trace_cfg))
+
+                _on_trace_ready = _log_handler
+            elif trace_type == 'tb_trace':  # tensorboard_trace handler
+                try:
+                    import torch_tb_profiler  # noqa: F401
+                except ImportError:
+                    raise ImportError('please run "pip install '
+                                      'torch-tb-profiler" to install '
+                                      'torch_tb_profiler')
+                _on_trace_ready = torch.profiler.tensorboard_trace_handler(
+                    **trace_cfg)
+            else:
+                raise ValueError('trace_type should be "log_trace" or '
+                                 f'"tb_trace", but got {trace_type}')
+        elif self.on_trace_ready is None:
+            _on_trace_ready = None  # type: ignore
+        else:
+            raise ValueError('on_trace_ready should be handler, dict or None, '
+                             f'but got {type(self.on_trace_ready)}')
+
+        # guard against an unset (None) ``max_epochs``, e.g. iter-based runs
+        if runner.max_epochs is not None and runner.max_epochs > 1:
+            warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
+                          'instead of 1 epoch. Since profiler will slow down '
+                          'the training, it is recommended to train 1 epoch '
+                          'with ProfilerHook and adjust your setting according'
+                          ' to the profiler summary. During normal training '
+                          '(epoch > 1), you may disable the ProfilerHook.')
+
+        self.profiler = torch.profiler.profile(
+            activities=self.activities,
+            schedule=self.schedule,
+            on_trace_ready=_on_trace_ready,
+            record_shapes=self.record_shapes,
+            profile_memory=self.profile_memory,
+            with_stack=self.with_stack,
+            with_flops=self.with_flops)
+        self.profiler.__enter__()
+        runner.logger.info('profiler is profiling...')
+
+    @master_only
+    def after_train_epoch(self, runner):
+        if self.by_epoch and runner.epoch == self.profile_iters - 1:
+            runner.logger.info('profiler may take a few minutes...')
+            self.profiler.__exit__(None, None, None)  # leave profiler context
+            if self.json_trace_path is not None:
+                self.profiler.export_chrome_trace(self.json_trace_path)
+
+    @master_only
+    def after_train_iter(self, runner):
+        self.profiler.step()  # called on every iteration, in both modes
+        if not self.by_epoch and runner.iter == self.profile_iters - 1:
+            runner.logger.info('profiler may take a few minutes...')
+            self.profiler.__exit__(None, None, None)  # leave profiler context
+            if self.json_trace_path is not None:
+                self.profiler.export_chrome_trace(self.json_trace_path)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class DistSamplerSeedHook(Hook):
+    """Call ``set_epoch`` on the data loader's sampler before every epoch.
+
+    Only useful together with :obj:`EpochBasedRunner`; :obj:`IterBasedRunner`
+    achieves the same purpose with :obj:`IterLoader`.
+    """
+
+    def before_epoch(self, runner):
+        loader = runner.data_loader
+        # prefer the loader's own sampler; otherwise fall back to the sampler
+        # wrapped inside its batch sampler
+        if hasattr(loader.sampler, 'set_epoch'):
+            loader.sampler.set_epoch(runner.epoch)
+        elif hasattr(loader.batch_sampler.sampler, 'set_epoch'):
+            loader.batch_sampler.sampler.set_epoch(runner.epoch)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..dist_utils import allreduce_params
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class SyncBuffersHook(Hook):
+    """All-reduce model buffers (e.g. BN running stats) after every epoch.
+
+    Args:
+        distributed (bool): Whether distributed training is used. It is
+            effective only for distributed training. Defaults to True.
+    """
+
+    def __init__(self, distributed=True):
+        self.distributed = distributed
+
+    def after_epoch(self, runner):
+        """All-reduce model buffers at the end of each epoch."""
+        if not self.distributed:
+            return
+        allreduce_params(runner.model.buffers())
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e93849ba8a0960d958c76151d5bdd406e4b795a4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.mmpkg.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .hooks import IterTimerHook
+from .utils import get_host_info
+
+
+class IterLoader:
+    """Endless iterator over a dataloader that counts completed passes."""
+
+    def __init__(self, dataloader):
+        self._dataloader = dataloader
+        self.iter_loader = iter(self._dataloader)
+        self._epoch = 0
+
+    @property
+    def epoch(self):
+        return self._epoch
+
+    def __next__(self):
+        try:
+            return next(self.iter_loader)
+        except StopIteration:
+            # exhausted: bump the epoch and start a fresh pass
+            self._epoch += 1
+            if hasattr(self._dataloader.sampler, 'set_epoch'):
+                self._dataloader.sampler.set_epoch(self._epoch)
+            time.sleep(2)  # Prevent possible deadlock during epoch transition
+            self.iter_loader = iter(self._dataloader)
+            return next(self.iter_loader)
+
+    def __len__(self):
+        return len(self._dataloader)
+
+
+@RUNNERS.register_module()
+class IterBasedRunner(BaseRunner):
+ """Iteration-based Runner.
+
+ This runner train models iteration by iteration.
+ """
+
+    def train(self, data_loader, **kwargs):
+        """Run one training iteration on the next batch of ``data_loader``."""
+        self.model.train()
+        self.mode, self.data_loader = 'train', data_loader
+        self._epoch = data_loader.epoch
+        batch = next(data_loader)
+        self.call_hook('before_train_iter')
+        result = self.model.train_step(batch, self.optimizer, **kwargs)
+        if not isinstance(result, dict):
+            raise TypeError('model.train_step() must return a dict')
+        if 'log_vars' in result:
+            self.log_buffer.update(result['log_vars'], result['num_samples'])
+        self.outputs = result
+        self.call_hook('after_train_iter')
+        self._inner_iter += 1
+        self._iter += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        """Run one validation iteration on the next ``data_loader`` batch."""
+        self.model.eval()
+        self.mode, self.data_loader = 'val', data_loader
+        batch = next(data_loader)
+        self.call_hook('before_val_iter')
+        result = self.model.val_step(batch, **kwargs)
+        if not isinstance(result, dict):
+            raise TypeError('model.val_step() must return a dict')
+        if 'log_vars' in result:
+            self.log_buffer.update(result['log_vars'], result['num_samples'])
+        self.outputs = result
+        self.call_hook('after_val_iter')
+        self._inner_iter += 1
+
+    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, iters) to specify the
+                running order and iterations. E.g, [('train', 10000),
+                ('val', 1000)] means running 10000 iterations for training and
+                1000 iterations for validation, iteratively.
+            max_iters (int, optional): Deprecated; overrides the runner's
+                ``max_iters`` when given.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_iters is not None:
+            warnings.warn(
+                'setting max_iters in run is deprecated, '
+                'please set max_iters in runner_config', DeprecationWarning)
+            self._max_iters = max_iters
+        assert self._max_iters is not None, (
+            'max_iters must be specified during instantiation')
+
+        work_dir = 'NONE' if self.work_dir is None else self.work_dir
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('Hooks will be executed in the following order:\n%s',
+                         self.get_hook_info())
+        self.logger.info('workflow: %s, max: %d iters', workflow,
+                         self._max_iters)
+        self.call_hook('before_run')
+        iter_loaders = [IterLoader(loader) for loader in data_loaders]
+        self.call_hook('before_epoch')
+
+        while self.iter < self._max_iters:
+            for idx, flow in enumerate(workflow):
+                self._inner_iter = 0
+                mode, iters = flow
+                if not (isinstance(mode, str) and hasattr(self, mode)):
+                    raise ValueError(
+                        'runner has no method named "{}" to run a workflow'.
+                        format(mode))
+                step_fn = getattr(self, mode)
+                for _ in range(iters):
+                    if mode == 'train' and self.iter >= self._max_iters:
+                        break
+                    step_fn(iter_loaders[idx], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_epoch')
+        self.call_hook('after_run')
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               map_location='default'):
+        """Resume model from checkpoint.
+
+        Args:
+            checkpoint (str): Checkpoint to resume from.
+            resume_optimizer (bool, optional): Whether resume the optimizer(s)
+                if the checkpoint file includes optimizer(s). Default to True.
+            map_location (str, optional): Same as :func:`torch.load`.
+                Defaults to 'default': the current CUDA device, or CPU if none.
+        """
+        if map_location == 'default' and torch.cuda.is_available():
+            device_id = torch.cuda.current_device()
+            checkpoint = self.load_checkpoint(
+                checkpoint,
+                map_location=lambda storage, loc: storage.cuda(device_id))
+        else:
+            target = 'cpu' if map_location == 'default' else map_location
+            checkpoint = self.load_checkpoint(checkpoint, map_location=target)
+
+        self._epoch = checkpoint['meta']['epoch']
+        self._iter = checkpoint['meta']['iter']
+        self._inner_iter = checkpoint['meta']['iter']
+        if 'optimizer' in checkpoint and resume_optimizer:
+            if isinstance(self.optimizer, Optimizer):
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(self.optimizer, dict):
+                for k in self.optimizer.keys():
+                    self.optimizer[k].load_state_dict(
+                        checkpoint['optimizer'][k])
+            else:
+                raise TypeError(
+                    'Optimizer should be dict or torch.optim.Optimizer '
+                    f'but got {type(self.optimizer)}')
+
+        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='iter_{}.pth',
+                        meta=None,
+                        save_optimizer=True,
+                        create_symlink=True):
+        """Save checkpoint to file.
+
+        Args:
+            out_dir (str): Directory to save checkpoint files.
+            filename_tmpl (str, optional): Checkpoint file template.
+                Defaults to 'iter_{}.pth'.
+            meta (dict, optional): Metadata to be saved in checkpoint; a
+                given dict is updated in place. Defaults to None.
+            save_optimizer (bool, optional): Whether save optimizer.
+                Defaults to True.
+            create_symlink (bool, optional): Whether create symlink to the
+                latest checkpoint file. Defaults to True.
+        """
+        if meta is None:
+            meta = {}
+        elif not isinstance(meta, dict):
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+        # Note: epoch/iter must be written *after* merging ``self.meta``,
+        # otherwise resumed checkpoints run into problems. More details in
+        # https://github.com/open-mmlab/mmcv/pull/1108
+        meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+        ckpt_name = filename_tmpl.format(self.iter + 1)
+        ckpt_path = osp.join(out_dir, ckpt_name)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(self.model, ckpt_path, optimizer=optimizer, meta=meta)
+
+        if not create_symlink:
+            # some environments do not support `os.symlink`
+            return
+        latest = osp.join(out_dir, 'latest.pth')
+        if platform.system() == 'Windows':
+            shutil.copy(ckpt_path, latest)
+        else:
+            mmcv.symlink(ckpt_name, latest)
+
+    def register_training_hooks(self,
+                                lr_config,
+                                optimizer_config=None,
+                                checkpoint_config=None,
+                                log_config=None,
+                                momentum_config=None,
+                                custom_hooks_config=None):
+        """Register default hooks for iter-based training.
+
+        Checkpoint hook, lr updater hook and logger hooks will be set to
+        `by_epoch=False` by default.
+
+        Default hooks include:
+
+        +----------------------+-------------------------+
+        | Hooks                | Priority                |
+        +======================+=========================+
+        | LrUpdaterHook        | VERY_HIGH (10)          |
+        +----------------------+-------------------------+
+        | MomentumUpdaterHook  | HIGH (30)               |
+        +----------------------+-------------------------+
+        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
+        +----------------------+-------------------------+
+        | CheckpointSaverHook  | NORMAL (50)             |
+        +----------------------+-------------------------+
+        | IterTimerHook        | LOW (70)                |
+        +----------------------+-------------------------+
+        | LoggerHook(s)        | VERY_LOW (90)           |
+        +----------------------+-------------------------+
+        | CustomHook(s)        | defaults to NORMAL (50) |
+        +----------------------+-------------------------+
+
+        If custom hooks have same priority with default hooks, custom hooks
+        will be triggered after default hooks.
+        """
+        # iteration-based training: these hooks default to by_epoch=False
+        for cfg in (checkpoint_config, lr_config):
+            if cfg is not None:
+                cfg.setdefault('by_epoch', False)
+        if log_config is not None:
+            for logger_cfg in log_config['hooks']:
+                logger_cfg.setdefault('by_epoch', False)
+        super(IterBasedRunner, self).register_training_hooks(
+            lr_config=lr_config,
+            momentum_config=momentum_config,
+            optimizer_config=optimizer_config,
+            checkpoint_config=checkpoint_config,
+            log_config=log_config,
+            timer_config=IterTimerHook(),
+            custom_hooks_config=custom_hooks_config)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import numpy as np
+
+
+class LogBuffer:
+    """Keep per-key value/count histories and their weighted averages."""
+
+    def __init__(self):
+        self.val_history, self.n_history = OrderedDict(), OrderedDict()
+        self.output = OrderedDict()
+        self.ready = False
+
+    def clear(self):
+        """Drop all recorded history and any computed averages."""
+        for history in (self.val_history, self.n_history):
+            history.clear()
+        self.clear_output()
+
+    def clear_output(self):
+        """Discard the computed averages and mark the buffer not ready."""
+        self.ready = False
+        self.output.clear()
+
+    def update(self, vars, count=1):
+        """Append each value in ``vars`` together with its ``count``."""
+        assert isinstance(vars, dict)
+        for name, value in vars.items():
+            self.val_history.setdefault(name, []).append(value)
+            self.n_history.setdefault(name, []).append(count)
+
+    def average(self, n=0):
+        """Store the count-weighted mean of the latest n (or all) values."""
+        assert n >= 0
+        for name, history in self.val_history.items():
+            # ``[-0:]`` is the whole list, so n == 0 averages everything.
+            values = np.array(history[-n:])
+            counts = np.array(self.n_history[name][-n:])
+            self.output[name] = np.sum(values * counts) / np.sum(counts)
+        self.ready = True
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
+ build_optimizer_constructor)
+from .default_constructor import DefaultOptimizerConstructor
+
+__all__ = [
+ 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
+ 'build_optimizer', 'build_optimizer_constructor'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+
+import torch
+
+from ...utils import Registry, build_from_cfg
+
+OPTIMIZERS = Registry('optimizer')
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
+
+
+def register_torch_optimizers():
+    """Register every optimizer class of ``torch.optim``; return their names."""
+    registered = []
+    for attr_name in (n for n in dir(torch.optim) if not n.startswith('__')):
+        candidate = getattr(torch.optim, attr_name)
+        is_optimizer = inspect.isclass(candidate) and issubclass(
+            candidate, torch.optim.Optimizer)
+        if is_optimizer:
+            OPTIMIZERS.register_module()(candidate)
+            registered.append(attr_name)
+    return registered
+
+
+TORCH_OPTIMIZERS = register_torch_optimizers()
+
+
+def build_optimizer_constructor(cfg):  # thin wrapper around build_from_cfg
+ return build_from_cfg(cfg, OPTIMIZER_BUILDERS)  # cfg is forwarded unchanged
+
+
+def build_optimizer(model, cfg):
+    """Build an optimizer for ``model`` from a deep copy of ``cfg``."""
+    remaining = copy.deepcopy(cfg)
+    # 'constructor' and 'paramwise_cfg' are popped for the constructor itself;
+    # whatever is left in ``remaining`` becomes its ``optimizer_cfg``.
+    builder_cfg = {
+        'type': remaining.pop('constructor', 'DefaultOptimizerConstructor'),
+        'optimizer_cfg': remaining,
+        'paramwise_cfg': remaining.pop('paramwise_cfg', None),
+    }
+    optim_constructor = build_optimizer_constructor(builder_cfg)
+    return optim_constructor(model)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..de2ae39cb6378cc17c098f5324f5d5c321879b91
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+from torch.nn import GroupNorm, LayerNorm
+
+from annotator.mmpkg.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
+from annotator.mmpkg.mmcv.utils.ext_loader import check_ops_exist
+from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class DefaultOptimizerConstructor:
+ """Default constructor for optimizers.
+
+ By default each parameter share the same optimizer settings, and we
+ provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+ It is a dict and may contain the following fields:
+
+ - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
+ one of the keys in ``custom_keys`` is a substring of the name of one
+ parameter, then the setting of the parameter will be specified by
+ ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
+ be ignored. It should be noted that the aforementioned ``key`` is the
+ longest key that is a substring of the name of the parameter. If there
+ are multiple matched keys with the same length, then the key with lower
+ alphabet order will be chosen.
+ ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+ and ``decay_mult``. See Example 2 below.
+ - ``bias_lr_mult`` (float): It will be multiplied to the learning
+ rate for all bias parameters (except for those in normalization
+ layers and offset layers of DCN).
+ - ``bias_decay_mult`` (float): It will be multiplied to the weight
+ decay for all bias parameters (except for those in
+ normalization layers, depthwise conv layers, offset layers of DCN).
+ - ``norm_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of normalization
+ layers.
+ - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
+ decay for all weight and bias parameters of depthwise conv
+ layers.
+ - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
+ rate for parameters of offset layer in the deformable convs
+ of a model.
+ - ``bypass_duplicate`` (bool): If true, the duplicate parameters
+ would not be added into optimizer. Default: False.
+
+ Note:
+ 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ override the effect of ``bias_lr_mult`` in the bias of offset
+ layer. So be careful when using both ``bias_lr_mult`` and
+ ``dcn_offset_lr_mult``. If you wish to apply both of them to the
+ offset layer in deformable convs, set ``dcn_offset_lr_mult``
+ to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
+ 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+ apply it to all the DCN layers in the model. So be careful when
+ the model contains multiple DCN layers in places other than
+ backbone.
+
+ Args:
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
+ optimizer_cfg (dict): The config dict of the optimizer.
+ Positional fields are
+
+ - `type`: class name of the optimizer.
+
+ Optional fields are
+
+ - any arguments of the corresponding optimizer type, e.g.,
+ lr, weight_decay, momentum, etc.
+ paramwise_cfg (dict, optional): Parameter-wise options.
+
+ Example 1:
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+ >>> weight_decay=0.0001)
+ >>> paramwise_cfg = dict(norm_decay_mult=0.)
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+
+ Example 2:
+ >>> # assume model have attribute model.backbone and model.cls_head
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+ >>> paramwise_cfg = dict(custom_keys={
+ '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+ >>> optim_builder = DefaultOptimizerConstructor(
+ >>> optimizer_cfg, paramwise_cfg)
+ >>> optimizer = optim_builder(model)
+ >>> # Then the `lr` and `weight_decay` for model.backbone is
+ >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
+ >>> # model.cls_head is (0.01, 0.95).
+ """
+
+ def __init__(self, optimizer_cfg, paramwise_cfg=None):
+     if not isinstance(optimizer_cfg, dict):  # reject non-dict configs early
+         raise TypeError('optimizer_cfg should be a dict, '
+                         f'but got {type(optimizer_cfg)}')
+     self.optimizer_cfg = optimizer_cfg
+     self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
+     self.base_lr = optimizer_cfg.get('lr', None)  # None if not configured
+     self.base_wd = optimizer_cfg.get('weight_decay', None)  # likewise
+     self._validate_cfg()  # needs base_wd, so it runs last
+
+ def _validate_cfg(self):
+     """Raise if ``paramwise_cfg`` is malformed or needs a missing base_wd."""
+     cfg = self.paramwise_cfg
+     if not isinstance(cfg, dict):
+         raise TypeError('paramwise_cfg should be None or a dict, '
+                         f'but got {type(cfg)}')
+
+     if 'custom_keys' in cfg:
+         custom_keys = cfg['custom_keys']
+         if not isinstance(custom_keys, dict):
+             raise TypeError('If specified, custom_keys must be a dict, '
+                             f'but got {type(custom_keys)}')
+         # a decay multiplier cannot be applied without a base weight decay
+         if self.base_wd is None and any(
+                 'decay_mult' in custom_keys[key] for key in custom_keys):
+             raise ValueError('base_wd should not be None')
+
+     # the global decay multipliers likewise require an explicit weight_decay
+     decay_options = ('bias_decay_mult', 'norm_decay_mult',
+                      'dwconv_decay_mult')
+     if self.base_wd is None and any(opt in cfg for opt in decay_options):
+         raise ValueError('base_wd should not be None')
+
+ def _is_in(self, param_group, param_group_list):
+     """Return True if any param of ``param_group`` is in an existing group."""
+     assert is_list_of(param_group_list, dict)
+     candidates = set(param_group['params'])
+     seen = set()
+     for existing in param_group_list:
+         seen |= set(existing['params'])
+     return bool(candidates & seen)
+
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ prefix (str): Dotted name of ``module``; prepended to parameter names.
+ is_dcn_module (bool|None): Truthy when the parent module is a DCN
+ layer (a bool is passed when this method recurses); it gates
+ the conv_offset learning-rate rule. Defaults to None.
+ """
+ # get param-wise options
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
+ # sort alphabetically, then (stably) by descending length: longest key first
+ sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+ bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
+ dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+ dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
+
+ # special rules for norm layers and depth-wise conv layers
+ is_norm = isinstance(module,
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+ is_dwconv = (
+ isinstance(module, torch.nn.Conv2d)
+ and module.in_channels == module.groups)
+
+ for name, param in module.named_parameters(recurse=False):
+ param_group = {'params': [param]}
+ if not param.requires_grad:  # added as-is, without lr/weight_decay keys
+ params.append(param_group)
+ continue
+ if bypass_duplicate and self._is_in(param_group, params):
+ warnings.warn(f'{prefix} is duplicate. It is skipped since '
+ f'bypass_duplicate={bypass_duplicate}')
+ continue
+ # if the parameter matches one of the custom keys, ignore other rules
+ is_custom = False
+ for key in sorted_keys:
+ if key in f'{prefix}.{name}':
+ is_custom = True
+ lr_mult = custom_keys[key].get('lr_mult', 1.)
+ param_group['lr'] = self.base_lr * lr_mult
+ if self.base_wd is not None:
+ decay_mult = custom_keys[key].get('decay_mult', 1.)
+ param_group['weight_decay'] = self.base_wd * decay_mult
+ break
+
+ if not is_custom:
+ # bias_lr_mult affects all bias parameters
+ # except for norm.bias dcn.conv_offset.bias
+ if name == 'bias' and not (is_norm or is_dcn_module):
+ param_group['lr'] = self.base_lr * bias_lr_mult
+
+ if (prefix.find('conv_offset') != -1 and is_dcn_module
+ and isinstance(module, torch.nn.Conv2d)):
+ # deal with both dcn_offset's bias & weight
+ param_group['lr'] = self.base_lr * dcn_offset_lr_mult
+
+ # apply weight decay policies
+ if self.base_wd is not None:
+ # norm decay
+ if is_norm:
+ param_group[
+ 'weight_decay'] = self.base_wd * norm_decay_mult
+ # depth-wise conv
+ elif is_dwconv:
+ param_group[
+ 'weight_decay'] = self.base_wd * dwconv_decay_mult
+ # bias lr and decay
+ elif name == 'bias' and not is_dcn_module:
+ # TODO: currently bias_decay_mult will have an effect on DCN
+ param_group[
+ 'weight_decay'] = self.base_wd * bias_decay_mult
+ params.append(param_group)
+
+ if check_ops_exist():
+ from annotator.mmpkg.mmcv.ops import DeformConv2d, ModulatedDeformConv2d
+ is_dcn_module = isinstance(module,
+ (DeformConv2d, ModulatedDeformConv2d))
+ else:
+ is_dcn_module = False
+ for child_name, child_mod in module.named_children():
+ child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+ self.add_params(
+ params,
+ child_mod,
+ prefix=child_prefix,
+ is_dcn_module=is_dcn_module)
+
+ def __call__(self, model):
+     """Build the optimizer for ``model`` (unwrapping ``model.module``)."""
+     # objects exposing a ``module`` attribute are replaced by that attribute
+     target = getattr(model, 'module', model)
+     build_cfg = self.optimizer_cfg.copy()
+
+     if self.paramwise_cfg:
+         # collect per-parameter lr / weight decay groups recursively
+         param_groups = []
+         self.add_params(param_groups, target)
+         build_cfg['params'] = param_groups
+     else:
+         # no param-wise options: every parameter uses the global settings
+         build_cfg['params'] = target.parameters()
+
+     return build_from_cfg(build_cfg, OPTIMIZERS)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/priority.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/priority.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+
+class Priority(Enum):
+ """Hook priority levels; a smaller value denotes a higher priority.
+
+ +--------------+------------+
+ | Level | Value |
+ +==============+============+
+ | HIGHEST | 0 |
+ +--------------+------------+
+ | VERY_HIGH | 10 |
+ +--------------+------------+
+ | HIGH | 30 |
+ +--------------+------------+
+ | ABOVE_NORMAL | 40 |
+ +--------------+------------+
+ | NORMAL | 50 |
+ +--------------+------------+
+ | BELOW_NORMAL | 60 |
+ +--------------+------------+
+ | LOW | 70 |
+ +--------------+------------+
+ | VERY_LOW | 90 |
+ +--------------+------------+
+ | LOWEST | 100 |
+ +--------------+------------+
+ """
+
+ HIGHEST = 0
+ VERY_HIGH = 10
+ HIGH = 30
+ ABOVE_NORMAL = 40
+ NORMAL = 50
+ BELOW_NORMAL = 60
+ LOW = 70
+ VERY_LOW = 90
+ LOWEST = 100
+
+
+def get_priority(priority):
+    """Normalize a priority given as int, name or enum member to an int.
+
+    Args:
+        priority (int or str or :obj:`Priority`): An integer in [0, 100],
+            a case-insensitive level name, or a :obj:`Priority` member.
+
+    Returns:
+        int: The numeric priority value.
+    """
+    if isinstance(priority, int):
+        if 0 <= priority <= 100:
+            return priority
+        raise ValueError('priority must be between 0 and 100')
+    if isinstance(priority, Priority):
+        return priority.value
+    if isinstance(priority, str):
+        return Priority[priority.upper()].value
+    raise TypeError('priority must be an integer or Priority enum value')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/utils.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..11bbc523e9a009119531c5eb903a93fe40cc5bca
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/runner/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+import sys
+import time
+import warnings
+from getpass import getuser
+from socket import gethostname
+
+import numpy as np
+import torch
+
+import annotator.mmpkg.mmcv as mmcv
+
+
+def get_host_info():
+    """Get ``user@hostname`` for the current process.
+
+    Returns an empty string (after emitting a warning) if the lookup raises,
+    e.g. ``getpass.getuser()`` can fail inside a docker container.
+    """
+    try:
+        return f'{getuser()}@{gethostname()}'
+    except Exception as e:
+        # No ``return`` in ``finally``: that would also swallow
+        # KeyboardInterrupt/SystemExit raised during the lookup.
+        warnings.warn(f'Host or user not found: {str(e)}')
+        return ''
+
+
+def get_time_str():
+    return time.strftime('%Y%m%d_%H%M%S')  # no tuple given -> local time
+
+
+def obj_from_dict(info, parent=None, default_args=None):
+    """Instantiate an object described by a dict.
+
+    ``info['type']`` is either a type or a string. A string is resolved as
+    an attribute of ``parent`` or, when ``parent`` is None, as a key of
+    ``sys.modules``. The remaining items become keyword arguments.
+
+    Args:
+        info (dict): Object type and arguments.
+        parent (:class:`module`): Module that may contain the expected
+            object classes.
+        default_args (dict, optional): Arguments used only for keys that
+            ``info`` does not already provide.
+
+    Returns:
+        any type: Object built from the dict.
+    """
+    assert isinstance(info, dict) and 'type' in info
+    assert isinstance(default_args, dict) or default_args is None
+    kwargs = info.copy()
+    target = kwargs.pop('type')
+    if mmcv.is_str(target):
+        if parent is None:
+            target = sys.modules[target]
+        else:
+            target = getattr(parent, target)
+    elif not isinstance(target, type):
+        raise TypeError('type must be a str or valid type, but '
+                        f'got {type(target)}')
+    if default_args is not None:
+        for name, value in default_args.items():
+            kwargs.setdefault(name, value)
+    return target(**kwargs)
+
+
+def set_random_seed(seed, deterministic=False, use_rank_shift=False):
+    """Seed Python, NumPy and PyTorch (CPU and CUDA) random generators.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+        use_rank_shift (bool): Whether to add the rank number to the seed
+            so that different ranks use different seeds. Default: False.
+    """
+    if use_rank_shift:
+        rank, _ = mmcv.runner.get_dist_info()
+        seed += rank
+    # every generator below (CPU and CUDA alike) receives the same seed
+    seeders = (random.seed, np.random.seed, torch.manual_seed,
+               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
+    for apply_seed in seeders:
+        apply_seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    if deterministic:
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/__init__.py
@@ -0,0 +1,69 @@
+# flake8: noqa
+# Copyright (c) OpenMMLab. All rights reserved.
+from .config import Config, ConfigDict, DictAction
+from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
+ has_method, import_modules_from_strings, is_list_of,
+ is_method_overridden, is_seq_of, is_str, is_tuple_of,
+ iter_cast, list_cast, requires_executable, requires_package,
+ slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+ to_ntuple, tuple_cast)
+from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
+ scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+ track_parallel_progress, track_progress)
+from .testing import (assert_attrs_equal, assert_dict_contains_subset,
+ assert_dict_has_keys, assert_is_norm_layer,
+ assert_keys_equal, assert_params_all_zeros,
+ check_python_script)
+from .timer import Timer, TimerError, check_time
+from .version_utils import digit_version, get_git_hash
+
+try:
+ import torch
+except ImportError:
+ __all__ = [
+ 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
+ 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
+ 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
+ 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
+ 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
+ 'track_progress', 'track_iter_progress', 'track_parallel_progress',
+ 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
+ 'digit_version', 'get_git_hash', 'import_modules_from_strings',
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
+ 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
+ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
+ 'is_method_overridden', 'has_method'
+ ]
+else:
+ from .env import collect_env
+ from .logging import get_logger, print_log
+ from .parrots_jit import jit, skip_no_elena
+ from .parrots_wrapper import (
+ TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
+ PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
+ _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
+ _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
+ from .registry import Registry, build_from_cfg
+ from .trace import is_jit_tracing
+ __all__ = [
+ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
+ 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
+ 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
+ 'check_prerequisites', 'requires_package', 'requires_executable',
+ 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
+ 'symlink', 'scandir', 'ProgressBar', 'track_progress',
+ 'track_iter_progress', 'track_parallel_progress', 'Registry',
+ 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
+ '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
+ '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
+ 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
+ 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
+ 'deprecated_api_warning', 'digit_version', 'get_git_hash',
+ 'import_modules_from_strings', 'jit', 'skip_no_elena',
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
+ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
+ 'assert_params_all_zeros', 'check_python_script',
+ 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
+ '_get_cuda_home', 'has_method'
+ ]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/config.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2f7551f95cbf5d8ffa225bba7325632b5e7f01b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/config.py
@@ -0,0 +1,688 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+from .misc import import_modules_from_strings
+from .path import check_file_exist
+
+if platform.system() == 'Windows':
+ import regex as re
+else:
+ import re
+
+BASE_KEY = '_base_'
+DELETE_KEY = '_delete_'
+DEPRECATION_KEY = '_deprecation_'
+RESERVED_KEYS = ['filename', 'text', 'pretty_text']
+
+
+class ConfigDict(Dict):
+ # ``Dict`` subclass whose absent-key hook raises (see ``__missing__``).
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:  # translate a missing key into AttributeError
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
+ f"attribute '{name}'")
+ except Exception as e:  # any other error is re-raised below unchanged
+ ex = e
+ else:
+ return value
+ raise ex  # raised after the try statement has finished, not in a handler
+
+
+def add_args(parser, cfg, prefix=''):
+    """Add a ``--<prefix><key>`` option to ``parser`` for each leaf of ``cfg``."""
+    for k, v in cfg.items():
+        opt = '--' + prefix + k
+        if isinstance(v, str):
+            parser.add_argument(opt)
+        elif isinstance(v, bool):  # must precede int: bool subclasses int
+            parser.add_argument(opt, action='store_true')
+        elif isinstance(v, (int, float)):
+            parser.add_argument(opt, type=int if isinstance(v, int) else float)
+        elif isinstance(v, dict):
+            add_args(parser, v, prefix + k + '.')
+        elif isinstance(v, abc.Iterable):
+            parser.add_argument(opt, type=type(v[0]), nargs='+')
+        else:
+            print(f'cannot parse key {prefix + k} of type {type(v)}')
+    return parser
+
+
+class Config:
+ """A facility for config and config files.
+
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+ """
+
+ @staticmethod
+ def _validate_py_syntax(filename):
+     """Read ``filename`` as UTF-8 and raise SyntaxError if it fails to parse."""
+     with open(filename, 'r', encoding='utf-8') as source_file:
+         source = source_file.read()
+     try:
+         ast.parse(source)
+     except SyntaxError as err:
+         raise SyntaxError('There are syntax errors in config '
+                           f'file {filename}: {err}')
+
+ @staticmethod
+ def _substitute_predefined_vars(filename, temp_config_name):
+     """Copy ``filename`` to ``temp_config_name``, expanding the predefined
+     ``{{ fileDirname }}``-style template variables."""
+     directory, basename = osp.split(filename)
+     templates = {
+         'fileDirname': directory,
+         'fileBasename': basename,
+         'fileBasenameNoExtension': osp.splitext(basename)[0],
+         'fileExtname': osp.splitext(filename)[1],
+     }
+     # utf-8 is given explicitly to avoid decoding problems on Windows
+     with open(filename, 'r', encoding='utf-8') as src:
+         text = src.read()
+     for var_name, replacement in templates.items():
+         pattern = r'\{\{\s*' + str(var_name) + r'\s*\}\}'
+         # backslashes in the substituted value are turned into '/'
+         text = re.sub(pattern, replacement.replace('\\', '/'), text)
+     with open(temp_config_name, 'w', encoding='utf-8') as dst:
+         dst.write(text)
+
+ @staticmethod
+ def _pre_substitute_base_vars(filename, temp_config_name):
+     """Replace ``{{_base_.x.y}}`` placeholders by quoted random strings so
+     that the file parses; return a dict mapping each string to ``x.y``."""
+     with open(filename, 'r', encoding='utf-8') as f:
+         config_file = f.read()
+     base_var_dict = {}
+     head, tail = r'\{\{\s*' + BASE_KEY + r'\.', r'\s*\}\}'
+     base_vars = set(re.findall(head + r'([\w\.]+)' + tail, config_file))
+     for base_var in base_vars:
+         randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
+         base_var_dict[randstr] = base_var
+         # escape: a raw '.' would also match other names such as ``aXb``
+         regexp = head + re.escape(base_var) + tail
+         config_file = re.sub(regexp, f'"{randstr}"', config_file)
+     with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+         tmp_config_file.write(config_file)
+     return base_var_dict
+
+ @staticmethod
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+     """Return a deep copy of ``cfg`` with placeholder strings resolved.
+
+     A string found in ``base_var_dict`` is replaced by the value reached by
+     following its dotted key path through ``base_cfg``.
+     """
+     def is_placeholder(item):
+         return isinstance(item, str) and item in base_var_dict
+     def lookup(placeholder):
+         node = base_cfg
+         for part in base_var_dict[placeholder].split('.'):
+             node = node[part]
+         return node
+
+     def recurse(item):
+         return Config._substitute_base_vars(item, base_var_dict, base_cfg)
+
+     cfg = copy.deepcopy(cfg)
+     if isinstance(cfg, dict):
+         for key, value in cfg.items():
+             if is_placeholder(value):
+                 cfg[key] = lookup(value)
+             elif isinstance(value, (list, tuple, dict)):
+                 cfg[key] = recurse(value)
+         return cfg
+     if isinstance(cfg, tuple):
+         return tuple(recurse(item) for item in cfg)
+     if isinstance(cfg, list):
+         return [recurse(item) for item in cfg]
+     return lookup(cfg) if is_placeholder(cfg) else cfg
+
+ @staticmethod
+ def _file2dict(filename, use_predefined_variables=True):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ fileExtname = osp.splitext(filename)[1]
+ if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json type are supported now!')
+
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(
+ dir=temp_config_dir, suffix=fileExtname)
+ if platform.system() == 'Windows':  # NOTE(review): presumably so it can be reopened by name - confirm
+ temp_config_file.close()
+ temp_config_name = osp.basename(temp_config_file.name)
+ # Substitute predefined variables
+ if use_predefined_variables:
+ Config._substitute_predefined_vars(filename,
+ temp_config_file.name)
+ else:
+ shutil.copyfile(filename, temp_config_file.name)
+ # Substitute base variables from placeholders to strings
+ base_var_dict = Config._pre_substitute_base_vars(
+ temp_config_file.name, temp_config_file.name)
+
+ if filename.endswith('.py'):
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ Config._validate_py_syntax(filename)  # checks the original file, not the temp copy
+ mod = import_module(temp_module_name)
+ sys.path.pop(0)  # NOTE(review): skipped if the two calls above raise
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith('__')
+ }
+ # delete imported module
+ del sys.modules[temp_module_name]
+ elif filename.endswith(('.yml', '.yaml', '.json')):
+ import annotator.mmpkg.mmcv as mmcv
+ cfg_dict = mmcv.load(temp_config_file.name)
+ # close temp file
+ temp_config_file.close()
+
+ # check deprecation information
+ if DEPRECATION_KEY in cfg_dict:
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+ warning_msg = f'The config file {filename} will be deprecated ' \
+ 'in the future.'
+ if 'expected' in deprecation_info:
+ warning_msg += f' Please use {deprecation_info["expected"]} ' \
+ 'instead.'
+ if 'reference' in deprecation_info:
+ warning_msg += ' More information can be found at ' \
+ f'{deprecation_info["reference"]}'
+ warnings.warn(warning_msg)
+
+ cfg_text = filename + '\n'
+ with open(filename, 'r', encoding='utf-8') as f:
+ # Setting encoding explicitly to resolve coding issue on windows
+ cfg_text += f.read()
+
+ if BASE_KEY in cfg_dict:
+ cfg_dir = osp.dirname(filename)
+ base_filename = cfg_dict.pop(BASE_KEY)
+ base_filename = base_filename if isinstance(
+ base_filename, list) else [base_filename]
+
+ cfg_dict_list = list()
+ cfg_text_list = list()
+ for f in base_filename:  # base paths are joined onto this file's directory
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+ cfg_dict_list.append(_cfg_dict)
+ cfg_text_list.append(_cfg_text)
+
+ base_cfg_dict = dict()
+ for c in cfg_dict_list:
+ duplicate_keys = base_cfg_dict.keys() & c.keys()
+ if len(duplicate_keys) > 0:
+ raise KeyError('Duplicate key is not allowed among bases. '
+ f'Duplicate keys: {duplicate_keys}')
+ base_cfg_dict.update(c)
+
+ # Substitute base variables from strings to their actual values
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+ base_cfg_dict)
+
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+ cfg_dict = base_cfg_dict
+
+ # merge cfg_text: base texts first, this file's text last
+ cfg_text_list.append(cfg_text)
+ cfg_text = '\n'.join(cfg_text_list)
+
+ return cfg_dict, cfg_text
+
+ @staticmethod
+ def _merge_a_into_b(a, b, allow_list_keys=False):
+ """merge dict ``a`` into dict ``b`` (non-inplace).
+
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
+ in-place modifications.
+
+ Args:
+ a (dict): The source dict to be merged into ``b``.
+ b (dict): The base dict that receives the keys of ``a``.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in source ``a`` and will replace the element of the
+ corresponding index in b if b is a list. Default: False.
+
+ Returns:
+ dict: The modified dict of ``b`` using ``a``.
+
+ Examples:
+ # Normally merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # Delete b first and merge a into b.
+ >>> Config._merge_a_into_b(
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
+ {'obj': {'a': 2}}
+
+ # b is a list
+ >>> Config._merge_a_into_b(
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
+ [{'a': 2}, {'b': 2}]
+ """
+ b = b.copy()  # shallow copy: nested values are replaced, not edited in place
+ for k, v in a.items():
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
+ k = int(k)
+ if len(b) <= k:
+ raise KeyError(f'Index {k} exceeds the length of list {b}')
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ elif isinstance(v,
+ dict) and k in b and not v.pop(DELETE_KEY, False):  # pop() also strips the key from ``a``'s dict
+ allowed_types = (dict, list) if allow_list_keys else dict
+ if not isinstance(b[k], allowed_types):
+ raise TypeError(
+ f'{k}={v} in child config cannot inherit from base '
+ f'because {k} is a dict in the child config but is of '
+ f'type {type(b[k])} in base config. You may set '
+ f'`{DELETE_KEY}=True` to ignore the base config')
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+ else:  # new key, non-dict value, or ``_delete_`` requested: overwrite
+ b[k] = v
+ return b
+
+ @staticmethod
+ def fromfile(filename,
+ use_predefined_variables=True,
+ import_custom_modules=True):
+ cfg_dict, cfg_text = Config._file2dict(filename,
+ use_predefined_variables)
+ if import_custom_modules and cfg_dict.get('custom_imports', None):
+ import_modules_from_strings(**cfg_dict['custom_imports'])
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+ @staticmethod
+ def fromstring(cfg_str, file_format):
+ """Generate config from config str.
+
+ Args:
+ cfg_str (str): Config str.
+ file_format (str): Config file format corresponding to the
+ config str. Only py/yml/yaml/json type are supported now!
+
+ Returns:
+ obj:`Config`: Config obj.
+ """
+ if file_format not in ['.py', '.json', '.yaml', '.yml']:
+ raise IOError('Only py/yml/yaml/json type are supported now!')
+ if file_format != '.py' and 'dict(' in cfg_str:
+ # check if users specify a wrong suffix for python
+ warnings.warn(
+ 'Please check "file_format", the file format may be .py')
+ with tempfile.NamedTemporaryFile(
+ 'w', encoding='utf-8', suffix=file_format,
+ delete=False) as temp_file:
+ temp_file.write(cfg_str)
+ # on windows, previous implementation cause error
+ # see PR 1077 for details
+ cfg = Config.fromfile(temp_file.name)
+ os.remove(temp_file.name)
+ return cfg
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)"""
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument('config', help='config file path')
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument('config', help='config file path')
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError('cfg_dict must be a dict, but '
+ f'got {type(cfg_dict)}')
+ for key in cfg_dict:
+ if key in RESERVED_KEYS:
+ raise KeyError(f'{key} is reserved for config file')
+
+ super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+ super(Config, self).__setattr__('_filename', filename)
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, 'r') as f:
+ text = f.read()
+ else:
+ text = ''
+ super(Config, self).__setattr__('_text', text)
+
+ @property
+    def filename(self):
+        """The ``filename`` value this config was constructed with."""
+        return self._filename
+
+ @property
+    def text(self):
+        """The config text stored at construction time."""
+        return self._text
+
+ @property
+ def pretty_text(self):
+
+ indent = 4
+
+ def _indent(s_, num_spaces):
+ s = s_.split('\n')
+ if len(s) == 1:
+ return s_
+ first = s.pop(0)
+ s = [(num_spaces * ' ') + line for line in s]
+ s = '\n'.join(s)
+ s = first + '\n' + s
+ return s
+
+ def _format_basic_types(k, v, use_mapping=False):
+ if isinstance(v, str):
+ v_str = f"'{v}'"
+ else:
+ v_str = str(v)
+
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent)
+
+ return attr_str
+
+ def _format_list(k, v, use_mapping=False):
+ # check if all items in the list are dict
+ if all(isinstance(_, dict) for _ in v):
+ v_str = '[\n'
+ v_str += '\n'.join(
+ f'dict({_indent(_format_dict(v_), indent)}),'
+ for v_ in v).rstrip(',')
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: {v_str}'
+ else:
+ attr_str = f'{str(k)}={v_str}'
+ attr_str = _indent(attr_str, indent) + ']'
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping)
+ return attr_str
+
+ def _contain_invalid_identifier(dict_str):
+ contain_invalid_identifier = False
+ for key_name in dict_str:
+ contain_invalid_identifier |= \
+ (not str(key_name).isidentifier())
+ return contain_invalid_identifier
+
+ def _format_dict(input_dict, outest_level=False):
+ r = ''
+ s = []
+
+ use_mapping = _contain_invalid_identifier(input_dict)
+ if use_mapping:
+ r += '{'
+ for idx, (k, v) in enumerate(input_dict.items()):
+ is_last = idx >= len(input_dict) - 1
+ end = '' if outest_level or is_last else ','
+ if isinstance(v, dict):
+ v_str = '\n' + _format_dict(v)
+ if use_mapping:
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
+ attr_str = f'{k_str}: dict({v_str}'
+ else:
+ attr_str = f'{str(k)}=dict({v_str}'
+ attr_str = _indent(attr_str, indent) + ')' + end
+ elif isinstance(v, list):
+ attr_str = _format_list(k, v, use_mapping) + end
+ else:
+ attr_str = _format_basic_types(k, v, use_mapping) + end
+
+ s.append(attr_str)
+ r += '\n'.join(s)
+ if use_mapping:
+ r += '}'
+ return r
+
+ cfg_dict = self._cfg_dict.to_dict()
+ text = _format_dict(cfg_dict, outest_level=True)
+ # copied from setup.cfg
+ yapf_style = dict(
+ based_on_style='pep8',
+ blank_line_before_nested_class_or_def=True,
+ split_before_expression_after_opening_paren=True)
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+ return text
+
+    def __repr__(self):
+        """Show the config file path followed by the repr of the content."""
+        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
+
+    def __len__(self):
+        """Return the length of the wrapped ``_cfg_dict``."""
+        return len(self._cfg_dict)
+
+    def __getattr__(self, name):
+        """Forward attribute lookups that fail on the Config itself to the
+        wrapped ``_cfg_dict``."""
+        return getattr(self._cfg_dict, name)
+
+    def __getitem__(self, name):
+        """Forward ``cfg[name]`` to the wrapped ``_cfg_dict``."""
+        return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+    def __iter__(self):
+        """Iterate over the wrapped ``_cfg_dict``."""
+        return iter(self._cfg_dict)
+
+    def __getstate__(self):
+        """Return the pickling state: ``(_cfg_dict, _filename, _text)``."""
+        return (self._cfg_dict, self._filename, self._text)
+
+ def __setstate__(self, state):
+ _cfg_dict, _filename, _text = state
+ super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
+ super(Config, self).__setattr__('_filename', _filename)
+ super(Config, self).__setattr__('_text', _text)
+
+ def dump(self, file=None):
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+ if self.filename.endswith('.py'):
+ if file is None:
+ return self.pretty_text
+ else:
+ with open(file, 'w', encoding='utf-8') as f:
+ f.write(self.pretty_text)
+ else:
+ import annotator.mmpkg.mmcv as mmcv
+ if file is None:
+ file_format = self.filename.split('.')[-1]
+ return mmcv.dump(cfg_dict, file_format=file_format)
+ else:
+ mmcv.dump(cfg_dict, file)
+
+ def merge_from_dict(self, options, allow_list_keys=True):
+ """Merge list into cfg_dict.
+
+ Merge the dict parsed by MultipleKVAction into this cfg.
+
+ Examples:
+ >>> options = {'model.backbone.depth': 50,
+ ... 'model.backbone.with_cp':True}
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+ >>> cfg.merge_from_dict(options)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
+
+ # Merge list element
+ >>> cfg = Config(dict(pipeline=[
+ ... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
+ >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
+ >>> cfg.merge_from_dict(options, allow_list_keys=True)
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ >>> assert cfg_dict == dict(pipeline=[
+ ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
+
+ Args:
+ options (dict): dict of configs to merge from.
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+ are allowed in ``options`` and will replace the element of the
+ corresponding index in the config if the config is a list.
+ Default: True.
+ """
+ option_cfg_dict = {}
+ for full_key, v in options.items():
+ d = option_cfg_dict
+ key_list = full_key.split('.')
+ for subkey in key_list[:-1]:
+ d.setdefault(subkey, ConfigDict())
+ d = d[subkey]
+ subkey = key_list[-1]
+ d[subkey] = v
+
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+ super(Config, self).__setattr__(
+ '_cfg_dict',
+ Config._merge_a_into_b(
+ option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
+
+
+class DictAction(Action):
+ """
+ argparse action to split an argument into KEY=VALUE form
+ on the first = and append to a dictionary. List options can
+ be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
+ brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
+ list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+ """
+
+ @staticmethod
+ def _parse_int_float_bool(val):
+ try:
+ return int(val)
+ except ValueError:
+ pass
+ try:
+ return float(val)
+ except ValueError:
+ pass
+ if val.lower() in ['true', 'false']:
+ return True if val.lower() == 'true' else False
+ return val
+
+ @staticmethod
+ def _parse_iterable(val):
+ """Parse iterable values in the string.
+
+ All elements inside '()' or '[]' are treated as iterable values.
+
+ Args:
+ val (str): Value string.
+
+ Returns:
+ list | tuple: The expanded list or tuple from the string.
+
+ Examples:
+ >>> DictAction._parse_iterable('1,2,3')
+ [1, 2, 3]
+ >>> DictAction._parse_iterable('[a, b, c]')
+ ['a', 'b', 'c']
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+ [(1, 2, 3), ['a', 'b'], 'c']
+ """
+
+ def find_next_comma(string):
+ """Find the position of next comma in the string.
+
+ If no ',' is found in the string, return the string length. All
+ chars inside '()' and '[]' are treated as one element and thus ','
+ inside these brackets are ignored.
+ """
+ assert (string.count('(') == string.count(')')) and (
+ string.count('[') == string.count(']')), \
+ f'Imbalanced brackets exist in {string}'
+ end = len(string)
+ for idx, char in enumerate(string):
+ pre = string[:idx]
+ # The string before this ',' is balanced
+ if ((char == ',') and (pre.count('(') == pre.count(')'))
+ and (pre.count('[') == pre.count(']'))):
+ end = idx
+ break
+ return end
+
+ # Strip ' and " characters and replace whitespace.
+ val = val.strip('\'\"').replace(' ', '')
+ is_tuple = False
+ if val.startswith('(') and val.endswith(')'):
+ is_tuple = True
+ val = val[1:-1]
+ elif val.startswith('[') and val.endswith(']'):
+ val = val[1:-1]
+ elif ',' not in val:
+ # val is a single value
+ return DictAction._parse_int_float_bool(val)
+
+ values = []
+ while len(val) > 0:
+ comma_idx = find_next_comma(val)
+ element = DictAction._parse_iterable(val[:comma_idx])
+ values.append(element)
+ val = val[comma_idx + 1:]
+ if is_tuple:
+ values = tuple(values)
+ return values
+
+ def __call__(self, parser, namespace, values, option_string=None):
+ options = {}
+ for kv in values:
+ key, val = kv.split('=', maxsplit=1)
+ options[key] = self._parse_iterable(val)
+ setattr(namespace, self.dest, options)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/env.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c6e64a63f8a3ed813b749c134823a0ef69964c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/env.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file holding some environment constant for sharing by other files."""
+
+import os.path as osp
+import subprocess
+import sys
+from collections import defaultdict
+
+import cv2
+import torch
+
+import annotator.mmpkg.mmcv as mmcv
+from .parrots_wrapper import get_build_config
+
+
+def collect_env():
+ """Collect the information of the running environments.
+
+ Returns:
+ dict: The environment information. The following fields are contained.
+
+ - sys.platform: The variable of ``sys.platform``.
+ - Python: Python version.
+ - CUDA available: Bool, indicating if CUDA is available.
+ - GPU devices: Device type of each GPU.
+ - CUDA_HOME (optional): The env var ``CUDA_HOME``.
+ - NVCC (optional): NVCC version.
+ - GCC: GCC version, "n/a" if GCC is not installed.
+ - PyTorch: PyTorch version.
+ - PyTorch compiling details: The output of \
+ ``torch.__config__.show()``.
+ - TorchVision (optional): TorchVision version.
+ - OpenCV: OpenCV version.
+ - MMCV: MMCV version.
+ - MMCV Compiler: The GCC version for compiling MMCV ops.
+ - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
+ """
+ env_info = {}
+ env_info['sys.platform'] = sys.platform
+ env_info['Python'] = sys.version.replace('\n', '')
+
+ cuda_available = torch.cuda.is_available()
+ env_info['CUDA available'] = cuda_available
+
+ if cuda_available:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ devices[torch.cuda.get_device_name(k)].append(str(k))
+ for name, device_ids in devices.items():
+ env_info['GPU ' + ','.join(device_ids)] = name
+
+ from annotator.mmpkg.mmcv.utils.parrots_wrapper import _get_cuda_home
+ CUDA_HOME = _get_cuda_home()
+ env_info['CUDA_HOME'] = CUDA_HOME
+
+ if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
+ try:
+ nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
+ nvcc = subprocess.check_output(
+ f'"{nvcc}" -V | tail -n1', shell=True)
+ nvcc = nvcc.decode('utf-8').strip()
+ except subprocess.SubprocessError:
+ nvcc = 'Not Available'
+ env_info['NVCC'] = nvcc
+
+ try:
+ gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
+ gcc = gcc.decode('utf-8').strip()
+ env_info['GCC'] = gcc
+ except subprocess.CalledProcessError: # gcc is unavailable
+ env_info['GCC'] = 'n/a'
+
+ env_info['PyTorch'] = torch.__version__
+ env_info['PyTorch compiling details'] = get_build_config()
+
+ try:
+ import torchvision
+ env_info['TorchVision'] = torchvision.__version__
+ except ModuleNotFoundError:
+ pass
+
+ env_info['OpenCV'] = cv2.__version__
+
+ env_info['MMCV'] = mmcv.__version__
+
+ try:
+ from annotator.mmpkg.mmcv.ops import get_compiler_version, get_compiling_cuda_version
+ except ModuleNotFoundError:
+ env_info['MMCV Compiler'] = 'n/a'
+ env_info['MMCV CUDA Compiler'] = 'n/a'
+ else:
+ env_info['MMCV Compiler'] = get_compiler_version()
+ env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()
+
+ return env_info
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..08132d2c1b9a1c28880e4bab4d4fa1ba39d9d083
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os
+import pkgutil
+import warnings
+from collections import namedtuple
+
+import torch
+
+# ``load_ext`` is defined differently depending on whether the installed
+# ``torch`` reports the version string 'parrots'.
+if torch.__version__ != 'parrots':
+
+    def load_ext(name, funcs):
+        """Import the extension module ``mmcv.<name>`` and check that it
+        provides every function listed in ``funcs``.
+
+        NOTE(review): the module path is hard-coded to the top-level
+        ``mmcv`` package rather than this vendored copy - confirm that this
+        is intended.
+        """
+        ext = importlib.import_module('mmcv.' + name)
+        for fun in funcs:
+            assert hasattr(ext, fun), f'{fun} miss in module {name}'
+        return ext
+else:
+    from parrots import extension
+    from parrots.base import ParrotsException
+
+    # Ops for which ``load_ext`` exposes ``.op`` instead of ``.op_``.
+    has_return_value_ops = [
+        'nms',
+        'softnms',
+        'nms_match',
+        'nms_rotated',
+        'top_pool_forward',
+        'top_pool_backward',
+        'bottom_pool_forward',
+        'bottom_pool_backward',
+        'left_pool_forward',
+        'left_pool_backward',
+        'right_pool_forward',
+        'right_pool_backward',
+        'fused_bias_leakyrelu',
+        'upfirdn2d',
+        'ms_deform_attn_forward',
+        'pixel_group',
+        'contour_expand',
+    ]
+
+    def get_fake_func(name, e):
+        """Return a stand-in for op ``name`` that warns and re-raises ``e``
+        when it is called."""
+
+        def fake_func(*args, **kwargs):
+            warnings.warn(f'{name} is not supported in parrots now')
+            raise e
+
+        return fake_func
+
+    def load_ext(name, funcs):
+        """Load every op in ``funcs`` from extension ``name`` and return
+        them as fields of an ``ExtModule`` namedtuple.
+
+        An op that fails to load with ``ParrotsException`` is replaced by a
+        fake function that raises that exception on call.
+        """
+        ExtModule = namedtuple('ExtModule', funcs)
+        ext_list = []
+        # Two directory levels above this file.
+        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+        for fun in funcs:
+            try:
+                ext_fun = extension.load(fun, name, lib_dir=lib_root)
+            except ParrotsException as e:
+                # Only warn for messages other than 'No element registered'.
+                if 'No element registered' not in e.message:
+                    warnings.warn(e.message)
+                ext_fun = get_fake_func(fun, e)
+                ext_list.append(ext_fun)
+            else:
+                if fun in has_return_value_ops:
+                    ext_list.append(ext_fun.op)
+                else:
+                    ext_list.append(ext_fun.op_)
+        return ExtModule(*ext_list)
+
+
+def check_ops_exist():
+ ext_loader = pkgutil.find_loader('mmcv._ext')
+ return ext_loader is not None
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/logging.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/logging.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.distributed as dist
+
+logger_initialized = {}
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+ """Initialize and get a logger by name.
+
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
+ will also be added.
+
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ file_mode (str): The file mode used in opening log file.
+ Defaults to 'w'.
+
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ # handle hierarchical names
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
+ # initialization since it is a child of "a".
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ # handle duplicate logs to the console
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
+ # to the root logger. As logger.propagate is True by default, this root
+ # level handler causes logging messages from rank>0 processes to
+ # unexpectedly show up on the console, creating much unwanted clutter.
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
+ # at the ERROR level.
+ for handler in logger.root.handlers:
+ if type(handler) is logging.StreamHandler:
+ handler.setLevel(logging.ERROR)
+
+ stream_handler = logging.StreamHandler()
+ handlers = [stream_handler]
+
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ else:
+ rank = 0
+
+ # only rank 0 will add a FileHandler
+ if rank == 0 and log_file is not None:
+ # Here, the default behaviour of the official logger is 'a'. Thus, we
+ # provide an interface to change the file mode to the default
+ # behaviour.
+ file_handler = logging.FileHandler(log_file, file_mode)
+ handlers.append(file_handler)
+
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ for handler in handlers:
+ handler.setFormatter(formatter)
+ handler.setLevel(log_level)
+ logger.addHandler(handler)
+
+ if rank == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+
+ logger_initialized[name] = True
+
+ return logger
+
+
+def print_log(msg, logger=None, level=logging.INFO):
+ """Print a log message.
+
+ Args:
+ msg (str): The message to be logged.
+ logger (logging.Logger | str | None): The logger to be used.
+ Some special loggers are:
+ - "silent": no message will be printed.
+ - other str: the logger obtained with `get_root_logger(logger)`.
+ - None: The `print()` method will be used to print log messages.
+ level (int): Logging level. Only available when `logger` is a Logger
+ object or "root".
+ """
+ if logger is None:
+ print(msg)
+ elif isinstance(logger, logging.Logger):
+ logger.log(level, msg)
+ elif logger == 'silent':
+ pass
+ elif isinstance(logger, str):
+ _logger = get_logger(logger)
+ _logger.log(level, msg)
+ else:
+ raise TypeError(
+ 'logger should be either a logging.Logger object, str, '
+ f'"silent" or None, but got {type(logger)}')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/misc.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/misc.py
@@ -0,0 +1,377 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections.abc
+import functools
+import itertools
+import subprocess
+import warnings
+from collections import abc
+from importlib import import_module
+from inspect import getfullargspec
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+# Ready-made converters: each repeats a non-iterable value into a tuple of
+# the given length and returns iterable inputs unchanged.
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+# Factory alias for arbitrary lengths: ``to_ntuple(n)(x)``.
+to_ntuple = _ntuple
+
+
+def is_str(x):
+    """Whether the input is a string instance.
+
+    Note: This method is deprecated since python 2 is no longer supported.
+
+    Args:
+        x (Any): The object to check.
+
+    Returns:
+        bool: True if ``x`` is an instance of ``str``.
+    """
+    return isinstance(x, str)
+
+
+def import_modules_from_strings(imports, allow_failed_imports=False):
+ """Import modules from the given list of strings.
+
+ Args:
+ imports (list | str | None): The given module names to be imported.
+ allow_failed_imports (bool): If True, the failed imports will return
+ None. Otherwise, an ImportError is raise. Default: False.
+
+ Returns:
+ list[module] | module | None: The imported modules.
+
+ Examples:
+ >>> osp, sys = import_modules_from_strings(
+ ... ['os.path', 'sys'])
+ >>> import os.path as osp_
+ >>> import sys as sys_
+ >>> assert osp == osp_
+ >>> assert sys == sys_
+ """
+ if not imports:
+ return
+ single_import = False
+ if isinstance(imports, str):
+ single_import = True
+ imports = [imports]
+ if not isinstance(imports, list):
+ raise TypeError(
+ f'custom_imports must be a list but got type {type(imports)}')
+ imported = []
+ for imp in imports:
+ if not isinstance(imp, str):
+ raise TypeError(
+ f'{imp} is of type {type(imp)} and cannot be imported.')
+ try:
+ imported_tmp = import_module(imp)
+ except ImportError:
+ if allow_failed_imports:
+ warnings.warn(f'{imp} failed to import and is ignored.',
+ UserWarning)
+ imported_tmp = None
+ else:
+ raise ImportError
+ imported.append(imported_tmp)
+ if single_import:
+ imported = imported[0]
+ return imported
+
+
+def iter_cast(inputs, dst_type, return_type=None):
+ """Cast elements of an iterable object into some type.
+
+ Args:
+ inputs (Iterable): The input object.
+ dst_type (type): Destination type.
+ return_type (type, optional): If specified, the output object will be
+ converted to this type, otherwise an iterator.
+
+ Returns:
+ iterator or specified type: The converted object.
+ """
+ if not isinstance(inputs, abc.Iterable):
+ raise TypeError('inputs must be an iterable object')
+ if not isinstance(dst_type, type):
+ raise TypeError('"dst_type" must be a valid type')
+
+ out_iterable = map(dst_type, inputs)
+
+ if return_type is None:
+ return out_iterable
+ else:
+ return return_type(out_iterable)
+
+
+def list_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a list of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=list)
+
+
+def tuple_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a tuple of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=tuple)
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def is_list_of(seq, expected_type):
+ """Check whether it is a list of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=list)
+
+
+def is_tuple_of(seq, expected_type):
+ """Check whether it is a tuple of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=tuple)
+
+
+def slice_list(in_list, lens):
+ """Slice a list into several sub lists by a list of given length.
+
+ Args:
+ in_list (list): The list to be sliced.
+ lens(int or list): The expected length of each out list.
+
+ Returns:
+ list: A list of sliced list.
+ """
+ if isinstance(lens, int):
+ assert len(in_list) % lens == 0
+ lens = [lens] * int(len(in_list) / lens)
+ if not isinstance(lens, list):
+ raise TypeError('"indices" must be an integer or a list of integers')
+ elif sum(lens) != len(in_list):
+ raise ValueError('sum of lens and list length does not '
+ f'match: {sum(lens)} != {len(in_list)}')
+ out_list = []
+ idx = 0
+ for i in range(len(lens)):
+ out_list.append(in_list[idx:idx + lens[i]])
+ idx += lens[i]
+ return out_list
+
+
+def concat_list(in_list):
+ """Concatenate a list of list into a single list.
+
+ Args:
+ in_list (list): The list of list to be merged.
+
+ Returns:
+ list: The concatenated flat list.
+ """
+ return list(itertools.chain(*in_list))
+
+
+def check_prerequisites(
+ prerequisites,
+ checker,
+ msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
+ 'found, please install them first.'): # yapf: disable
+ """A decorator factory to check if prerequisites are satisfied.
+
+ Args:
+ prerequisites (str of list[str]): Prerequisites to be checked.
+ checker (callable): The checker method that returns True if a
+ prerequisite is meet, False otherwise.
+ msg_tmpl (str): The message template with two variables.
+
+ Returns:
+ decorator: A specific decorator.
+ """
+
+ def wrap(func):
+
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ requirements = [prerequisites] if isinstance(
+ prerequisites, str) else prerequisites
+ missing = []
+ for item in requirements:
+ if not checker(item):
+ missing.append(item)
+ if missing:
+ print(msg_tmpl.format(', '.join(missing), func.__name__))
+ raise RuntimeError('Prerequisites not meet.')
+ else:
+ return func(*args, **kwargs)
+
+ return wrapped_func
+
+ return wrap
+
+
+def _check_py_package(package):
+ try:
+ import_module(package)
+ except ImportError:
+ return False
+ else:
+ return True
+
+
+def _check_executable(cmd):
+ if subprocess.call(f'which {cmd}', shell=True) != 0:
+ return False
+ else:
+ return True
+
+
+def requires_package(prerequisites):
+    """A decorator to check if some python packages are installed.
+
+    The check is delegated to ``check_prerequisites`` with
+    ``_check_py_package`` as checker, so a missing package makes the
+    decorated function raise ``RuntimeError`` when it is called.
+
+    Args:
+        prerequisites (str | list[str]): Package name(s) to check.
+
+    Example:
+        >>> @requires_package('numpy')
+        ... def func(arg1, args):
+        ...     return numpy.zeros(1)
+        >>> func(1, 2)
+        array([0.])
+        >>> @requires_package(['numpy', 'non_package'])
+        ... def func(arg1, args):
+        ...     return numpy.zeros(1)
+        >>> func(1, 2)
+        RuntimeError
+    """
+    return check_prerequisites(prerequisites, checker=_check_py_package)
+
+
+def requires_executable(prerequisites):
+    """A decorator to check if some executable files are installed.
+
+    The check is delegated to ``check_prerequisites`` with
+    ``_check_executable`` as checker.
+
+    Args:
+        prerequisites (str | list[str]): Executable name(s) to check.
+
+    Example:
+        >>> @requires_executable('ffmpeg')
+        ... def func(arg1, args):
+        ...     print(1)
+        >>> func(1, 2)
+        1
+    """
+    return check_prerequisites(prerequisites, checker=_check_executable)
+
+
+def deprecated_api_warning(name_dict, cls_name=None):
+    """A decorator to check if some arguments are deprecate and try to replace
+    deprecate src_arg_name to dst_arg_name.
+
+    Deprecated names passed as keyword arguments are renamed before the
+    wrapped function is called. For positional arguments only a warning is
+    emitted; the values are passed through unchanged.
+
+    Args:
+        name_dict(dict):
+            key (str): Deprecate argument names.
+            val (str): Expected argument names.
+        cls_name (str, optional): If given, it is prepended to the function
+            name in the warning text.
+
+    Returns:
+        func: New function.
+    """
+
+    def api_warning_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get name of the function
+            func_name = old_func.__name__
+            if cls_name is not None:
+                func_name = f'{cls_name}.{func_name}'
+            if args:
+                # Names of the parameters that were filled positionally.
+                arg_names = args_info.args[:len(args)]
+                for src_arg_name, dst_arg_name in name_dict.items():
+                    if src_arg_name in arg_names:
+                        warnings.warn(
+                            f'"{src_arg_name}" is deprecated in '
+                            f'`{func_name}`, please use "{dst_arg_name}" '
+                            'instead')
+                        # Only the local name list is updated here; it is
+                        # not read again, ``args`` itself is left untouched.
+                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
+            if kwargs:
+                for src_arg_name, dst_arg_name in name_dict.items():
+                    if src_arg_name in kwargs:
+
+                        # Passing the old and the new name together is
+                        # ambiguous and therefore rejected.
+                        assert dst_arg_name not in kwargs, (
+                            f'The expected behavior is to replace '
+                            f'the deprecated key `{src_arg_name}` to '
+                            f'new key `{dst_arg_name}`, but got them '
+                            f'in the arguments at the same time, which '
+                            f'is confusing. `{src_arg_name} will be '
+                            f'deprecated in the future, please '
+                            f'use `{dst_arg_name}` instead.')
+
+                        warnings.warn(
+                            f'"{src_arg_name}" is deprecated in '
+                            f'`{func_name}`, please use "{dst_arg_name}" '
+                            'instead')
+                        # Move the value over to the new keyword.
+                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
+
+            # apply converted arguments to the decorated method
+            output = old_func(*args, **kwargs)
+            return output
+
+        return new_func
+
+    return api_warning_wrapper
+
+
+def is_method_overridden(method, base_class, derived_class):
+ """Check if a method of base class is overridden in derived class.
+
+ Args:
+ method (str): the method name to check.
+ base_class (type): the class of the base class.
+ derived_class (type | Any): the class or instance of the derived class.
+ """
+ assert isinstance(base_class, type), \
+ "base_class doesn't accept instance, Please pass class instead."
+
+ if not isinstance(derived_class, type):
+ derived_class = derived_class.__class__
+
+ base_method = getattr(base_class, method)
+ derived_method = getattr(derived_class, method)
+ return derived_method != base_method
+
+
+def has_method(obj: object, method: str) -> bool:
+ """Check whether the object has a method.
+
+ Args:
+ method (str): The method name to check.
+ obj (object): The object to check.
+
+ Returns:
+ bool: True if the object has the method else False.
+ """
+ return hasattr(obj, method) and callable(getattr(obj, method))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+from .parrots_wrapper import TORCH_VERSION
+
+# The parrots jit is only used when the environment variable
+# PARROTS_JIT_OPTION is set to 'ON'.
+parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
+
+if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
+    from parrots.jit import pat as jit
+else:
+
+    def jit(func=None,
+            check_input=None,
+            full_shape=True,
+            derivate=False,
+            coderize=False,
+            optimize=False):
+        """No-op stand-in for the parrots ``jit`` decorator.
+
+        Used as ``@jit`` it returns ``func`` unchanged; used as
+        ``@jit(...)`` it returns a decorator whose wrapper simply calls the
+        decorated function. The remaining parameters are accepted and
+        ignored.
+        """
+
+        def wrapper(func):
+
+            def wrapper_inner(*args, **kargs):
+                return func(*args, **kargs)
+
+            return wrapper_inner
+
+        if func is None:
+            return wrapper
+        else:
+            return func
+
+
+# ``skip_no_elena`` comes from parrots when available; otherwise it is a
+# pass-through decorator.
+if TORCH_VERSION == 'parrots':
+    from parrots.utils.tester import skip_no_elena
+else:
+
+    def skip_no_elena(func):
+        """Pass-through decorator: the wrapper just calls ``func``."""
+
+        def wrapper(*args, **kargs):
+            return func(*args, **kargs)
+
+        return wrapper
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import torch
+
+TORCH_VERSION = torch.__version__
+
+
+def is_rocm_pytorch() -> bool:
+ is_rocm = False
+ if TORCH_VERSION != 'parrots':
+ try:
+ from torch.utils.cpp_extension import ROCM_HOME
+ is_rocm = True if ((torch.version.hip is not None) and
+ (ROCM_HOME is not None)) else False
+ except ImportError:
+ pass
+ return is_rocm
+
+
+def _get_cuda_home():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import CUDA_HOME
+ else:
+ if is_rocm_pytorch():
+ from torch.utils.cpp_extension import ROCM_HOME
+ CUDA_HOME = ROCM_HOME
+ else:
+ from torch.utils.cpp_extension import CUDA_HOME
+ return CUDA_HOME
+
+
+def get_build_config():
+ if TORCH_VERSION == 'parrots':
+ from parrots.config import get_build_info
+ return get_build_info()
+ else:
+ return torch.__config__.show()
+
+
+def _get_conv():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+ else:
+ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+ return _ConvNd, _ConvTransposeMixin
+
+
+def _get_dataloader():
+ if TORCH_VERSION == 'parrots':
+ from torch.utils.data import DataLoader, PoolDataLoader
+ else:
+ from torch.utils.data import DataLoader
+ PoolDataLoader = DataLoader
+ return DataLoader, PoolDataLoader
+
+
+def _get_extension():
+ if TORCH_VERSION == 'parrots':
+ from parrots.utils.build_extension import BuildExtension, Extension
+ CppExtension = partial(Extension, cuda=False)
+ CUDAExtension = partial(Extension, cuda=True)
+ else:
+ from torch.utils.cpp_extension import (BuildExtension, CppExtension,
+ CUDAExtension)
+ return BuildExtension, CppExtension, CUDAExtension
+
+
+def _get_pool():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
+ _MaxPoolNd)
+ else:
+ from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
+ _MaxPoolNd)
+ return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
+
+
+def _get_norm():
+ if TORCH_VERSION == 'parrots':
+ from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
+ else:
+ from torch.nn.modules.instancenorm import _InstanceNorm
+ from torch.nn.modules.batchnorm import _BatchNorm
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm
+ return _BatchNorm, _InstanceNorm, SyncBatchNorm_
+
+
+# Evaluated at import time: each helper picks the parrots or torch
+# implementation based on TORCH_VERSION and the results are bound as
+# module-level names.
+_ConvNd, _ConvTransposeMixin = _get_conv()
+DataLoader, PoolDataLoader = _get_dataloader()
+BuildExtension, CppExtension, CUDAExtension = _get_extension()
+_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
+_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
+
+
+class SyncBatchNorm(SyncBatchNorm_):
+
+ def _check_input_dim(self, input):
+ if TORCH_VERSION == 'parrots':
+ if input.dim() < 2:
+ raise ValueError(
+ f'expected at least 2D input (got {input.dim()}D input)')
+ else:
+ super()._check_input_dim(input)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/path.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/path.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from pathlib import Path
+
+from .misc import is_str
+
+
+def is_filepath(x):
+ return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+ raise ValueError('`filepath` should be a string or a Path')
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == '':
+ return
+ dir_name = osp.expanduser(dir_name)
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str | obj:`Path`): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ case_sensitive (bool, optional) : If set to False, ignore the case of
+ suffix. Default: True.
+
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+ if isinstance(dir_path, (str, Path)):
+ dir_path = str(dir_path)
+ else:
+ raise TypeError('"dir_path" must be a string or Path object')
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ if suffix is not None and not case_sensitive:
+ suffix = suffix.lower() if isinstance(suffix, str) else tuple(
+ item.lower() for item in suffix)
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ rel_path = osp.relpath(entry.path, root)
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
+ if suffix is None or _rel_path.endswith(suffix):
+ yield rel_path
+ elif recursive and os.path.isdir(entry.path):
+ # scan recursively if entry.path is a directory
+ yield from _scandir(entry.path, suffix, recursive,
+ case_sensitive)
+
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
+
+
+def find_vcs_root(path, markers=('.git', )):
+ """Finds the root directory (including itself) of specified markers.
+
+ Args:
+ path (str): Path of directory or file.
+ markers (list[str], optional): List of file or directory names.
+
+ Returns:
+ The directory contained one of the markers or None if not found.
+ """
+ if osp.isfile(path):
+ path = osp.dirname(path)
+
+ prev, cur = None, osp.abspath(osp.expanduser(path))
+ while cur != prev:
+ if any(osp.exists(osp.join(cur, marker)) for marker in markers):
+ return cur
+ prev, cur = cur, osp.split(cur)[0]
+ return None
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from collections.abc import Iterable
+from multiprocessing import Pool
+from shutil import get_terminal_size
+
+from .timer import Timer
+
+
+class ProgressBar:
+ """A progress bar which can print the progress."""
+
+ def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
+ self.task_num = task_num
+ self.bar_width = bar_width
+ self.completed = 0
+ self.file = file
+ if start:
+ self.start()
+
+ @property
+ def terminal_width(self):
+ width, _ = get_terminal_size()
+ return width
+
+ def start(self):
+ if self.task_num > 0:
+ self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
+ 'elapsed: 0s, ETA:')
+ else:
+ self.file.write('completed: 0, elapsed: 0s')
+ self.file.flush()
+ self.timer = Timer()
+
+ def update(self, num_tasks=1):
+ assert num_tasks > 0
+ self.completed += num_tasks
+ elapsed = self.timer.since_start()
+ if elapsed > 0:
+ fps = self.completed / elapsed
+ else:
+ fps = float('inf')
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
+ f'ETA: {eta:5}s'
+
+ bar_width = min(self.bar_width,
+ int(self.terminal_width - len(msg)) + 2,
+ int(self.terminal_width * 0.6))
+ bar_width = max(2, bar_width)
+ mark_width = int(bar_width * percentage)
+ bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
+ self.file.write(msg.format(bar_chars))
+ else:
+ self.file.write(
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
+ f' {fps:.1f} tasks/s')
+ self.file.flush()
+
+
+def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
+ """Track the progress of tasks execution with a progress bar.
+
+ Tasks are done with a simple for-loop.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ results = []
+ for task in tasks:
+ results.append(func(task, **kwargs))
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ return results
+
+
+def init_pool(process_num, initializer=None, initargs=None):
+ if initializer is None:
+ return Pool(process_num)
+ elif initargs is None:
+ return Pool(process_num, initializer)
+ else:
+ if not isinstance(initargs, tuple):
+ raise TypeError('"initargs" must be a tuple')
+ return Pool(process_num, initializer, initargs)
+
+
+def track_parallel_progress(func,
+ tasks,
+ nproc,
+ initializer=None,
+ initargs=None,
+ bar_width=50,
+ chunksize=1,
+ skip_first=False,
+ keep_order=True,
+ file=sys.stdout):
+ """Track the progress of parallel task execution with a progress bar.
+
+ The built-in :mod:`multiprocessing` module is used for process pools and
+ tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ nproc (int): Process (worker) number.
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+ for details.
+ initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+ details.
+ chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+ bar_width (int): Width of progress bar.
+ skip_first (bool): Whether to skip the first sample for each worker
+ when estimating fps, since the initialization step may takes
+ longer.
+ keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+ :func:`Pool.imap_unordered` is used.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ pool = init_pool(nproc, initializer, initargs)
+ start = not skip_first
+ task_num -= nproc * chunksize * int(skip_first)
+ prog_bar = ProgressBar(task_num, bar_width, start, file=file)
+ results = []
+ if keep_order:
+ gen = pool.imap(func, tasks, chunksize)
+ else:
+ gen = pool.imap_unordered(func, tasks, chunksize)
+ for result in gen:
+ results.append(result)
+ if skip_first:
+ if len(results) < nproc * chunksize:
+ continue
+ elif len(results) == nproc * chunksize:
+ prog_bar.start()
+ continue
+ prog_bar.update()
+ prog_bar.file.write('\n')
+ pool.close()
+ pool.join()
+ return results
+
+
+def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
+ """Track the progress of tasks iteration or enumeration with a progress
+ bar.
+
+ Tasks are yielded with a simple for-loop.
+
+ Args:
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Yields:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError(
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
+ for task in tasks:
+ yield task
+ prog_bar.update()
+ prog_bar.file.write('\n')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/registry.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/registry.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+
+from .misc import is_seq_of
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from config dict.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+ if default_args is None or 'type' not in default_args:
+ raise KeyError(
+ '`cfg` or `default_args` must contain the key "type", '
+ f'but got {cfg}\n{default_args}')
+ if not isinstance(registry, Registry):
+ raise TypeError('registry must be an mmcv.Registry object, '
+ f'but got {type(registry)}')
+ if not (isinstance(default_args, dict) or default_args is None):
+ raise TypeError('default_args must be a dict or None, '
+ f'but got {type(default_args)}')
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop('type')
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(
+ f'{obj_type} is not in the {registry.name} registry')
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+ try:
+ return obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f'{obj_cls.__name__}: {e}')
+
+
+class Registry:
+ """A registry to map strings to classes.
+
+ Registered object could be built from registry.
+ Example:
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = MODELS.build(dict(type='ResNet'))
+
+ Please refer to
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+ advanced usage.
+
+ Args:
+ name (str): Registry name.
+ build_func(func, optional): Build function to construct instance from
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
+ ``build_func`` is specified. If ``parent`` is specified and
+ ``build_func`` is not given, ``build_func`` will be inherited
+ from ``parent``. Default: None.
+ parent (Registry, optional): Parent registry. The class registered in
+ children registry could be built from parent. Default: None.
+ scope (str, optional): The scope of registry. It is the key to search
+ for children registry. If not specified, scope will be the name of
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
+ Default: None.
+ """
+
+ def __init__(self, name, build_func=None, parent=None, scope=None):
+ self._name = name
+ self._module_dict = dict()
+ self._children = dict()
+ self._scope = self.infer_scope() if scope is None else scope
+
+ # self.build_func will be set with the following priority:
+ # 1. build_func
+ # 2. parent.build_func
+ # 3. build_from_cfg
+ if build_func is None:
+ if parent is not None:
+ self.build_func = parent.build_func
+ else:
+ self.build_func = build_from_cfg
+ else:
+ self.build_func = build_func
+ if parent is not None:
+ assert isinstance(parent, Registry)
+ parent._add_children(self)
+ self.parent = parent
+ else:
+ self.parent = None
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ def __contains__(self, key):
+ return self.get(key) is not None
+
+ def __repr__(self):
+ format_str = self.__class__.__name__ + \
+ f'(name={self._name}, ' \
+ f'items={self._module_dict})'
+ return format_str
+
+ @staticmethod
+ def infer_scope():
+ """Infer the scope of registry.
+
+ The name of the package where registry is defined will be returned.
+
+ Example:
+ # in mmdet/models/backbone/resnet.py
+ >>> MODELS = Registry('models')
+ >>> @MODELS.register_module()
+ >>> class ResNet:
+ >>> pass
+ The scope of ``ResNet`` will be ``mmdet``.
+
+
+ Returns:
+ scope (str): The inferred scope name.
+ """
+ # inspect.stack() trace where this function is called, the index-2
+ # indicates the frame where `infer_scope()` is called
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+ split_filename = filename.split('.')
+ return split_filename[0]
+
+ @staticmethod
+ def split_scope_key(key):
+ """Split scope and key.
+
+ The first scope will be split from key.
+
+ Examples:
+ >>> Registry.split_scope_key('mmdet.ResNet')
+ 'mmdet', 'ResNet'
+ >>> Registry.split_scope_key('ResNet')
+ None, 'ResNet'
+
+ Return:
+ scope (str, None): The first scope.
+ key (str): The remaining key.
+ """
+ split_index = key.find('.')
+ if split_index != -1:
+ return key[:split_index], key[split_index + 1:]
+ else:
+ return None, key
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ @property
+ def children(self):
+ return self._children
+
+ def get(self, key):
+ """Get the registry record.
+
+ Args:
+ key (str): The class name in string format.
+
+ Returns:
+ class: The corresponding class.
+ """
+ scope, real_key = self.split_scope_key(key)
+ if scope is None or scope == self._scope:
+ # get from self
+ if real_key in self._module_dict:
+ return self._module_dict[real_key]
+ else:
+ # get from self._children
+ if scope in self._children:
+ return self._children[scope].get(real_key)
+ else:
+ # goto root
+ parent = self.parent
+ while parent.parent is not None:
+ parent = parent.parent
+ return parent.get(key)
+
+ def build(self, *args, **kwargs):
+ return self.build_func(*args, **kwargs, registry=self)
+
+ def _add_children(self, registry):
+ """Add children for a registry.
+
+ The ``registry`` will be added as children based on its scope.
+ The parent registry could build objects from children registry.
+
+ Example:
+ >>> models = Registry('models')
+ >>> mmdet_models = Registry('models', parent=models)
+ >>> @mmdet_models.register_module()
+ >>> class ResNet:
+ >>> pass
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
+ """
+
+ assert isinstance(registry, Registry)
+ assert registry.scope is not None
+ assert registry.scope not in self.children, \
+ f'scope {registry.scope} exists in {self.name} registry'
+ self.children[registry.scope] = registry
+
+ def _register_module(self, module_class, module_name=None, force=False):
+ if not inspect.isclass(module_class):
+ raise TypeError('module must be a class, '
+ f'but got {type(module_class)}')
+
+ if module_name is None:
+ module_name = module_class.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in self._module_dict:
+ raise KeyError(f'{name} is already registered '
+ f'in {self.name}')
+ self._module_dict[name] = module_class
+
+ def deprecated_register_module(self, cls=None, force=False):
+ warnings.warn(
+ 'The old API of register_module(module, force=False) '
+ 'is deprecated and will be removed, please use the new API '
+ 'register_module(name=None, force=False, module=None) instead.')
+ if cls is None:
+ return partial(self.deprecated_register_module, force=force)
+ self._register_module(cls, force=force)
+ return cls
+
+ def register_module(self, name=None, force=False, module=None):
+ """Register a module.
+
+ A record will be added to `self._module_dict`, whose key is the class
+ name or the specified name, and value is the class itself.
+ It can be used as a decorator or a normal function.
+
+ Example:
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module()
+ >>> class ResNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> @backbones.register_module(name='mnet')
+ >>> class MobileNet:
+ >>> pass
+
+ >>> backbones = Registry('backbone')
+ >>> class ResNet:
+ >>> pass
+ >>> backbones.register_module(ResNet)
+
+ Args:
+ name (str | None): The module name to be registered. If not
+ specified, the class name will be used.
+ force (bool, optional): Whether to override an existing class with
+ the same name. Default: False.
+ module (type): Module class to be registered.
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
+ # NOTE: This is a walkaround to be compatible with the old api,
+ # while it may introduce unexpected bugs.
+ if isinstance(name, type):
+ return self.deprecated_register_module(name, force=force)
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+ raise TypeError(
+ 'name must be either of None, an instance of str or a sequence'
+ f' of str, but got {type(name)}')
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ self._register_module(
+ module_class=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(cls):
+ self._register_module(
+ module_class=cls, module_name=name, force=force)
+ return cls
+
+ return _register
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/testing.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/testing.py
@@ -0,0 +1,140 @@
+# Copyright (c) Open-MMLab.
+import sys
+from collections.abc import Iterable
+from runpy import run_path
+from shlex import split
+from typing import Any, Dict, List
+from unittest.mock import patch
+
+
+def check_python_script(cmd):
+ """Run the python cmd script with `__main__`. The difference between
+ `os.system` is that, this function exectues code in the current process, so
+ that it can be tracked by coverage tools. Currently it supports two forms:
+
+ - ./tests/data/scripts/hello.py zz
+ - python tests/data/scripts/hello.py zz
+ """
+ args = split(cmd)
+ if args[0] == 'python':
+ args = args[1:]
+ with patch.object(sys, 'argv', args):
+ run_path(args[0], run_name='__main__')
+
+
+def _any(judge_result):
+ """Since built-in ``any`` works only when the element of iterable is not
+ iterable, implement the function."""
+ if not isinstance(judge_result, Iterable):
+ return judge_result
+
+ try:
+ for element in judge_result:
+ if _any(element):
+ return True
+ except TypeError:
+ # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
+ if judge_result:
+ return True
+ return False
+
+
+def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
+ expected_subset: Dict[Any, Any]) -> bool:
+ """Check if the dict_obj contains the expected_subset.
+
+ Args:
+ dict_obj (Dict[Any, Any]): Dict object to be checked.
+ expected_subset (Dict[Any, Any]): Subset expected to be contained in
+ dict_obj.
+
+ Returns:
+ bool: Whether the dict_obj contains the expected_subset.
+ """
+
+ for key, value in expected_subset.items():
+ if key not in dict_obj.keys() or _any(dict_obj[key] != value):
+ return False
+ return True
+
+
+def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
+ """Check if attribute of class object is correct.
+
+ Args:
+ obj (object): Class object to be checked.
+ expected_attrs (Dict[str, Any]): Dict of the expected attrs.
+
+ Returns:
+ bool: Whether the attribute of class object is correct.
+ """
+ for attr, value in expected_attrs.items():
+ if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
+ return False
+ return True
+
+
+def assert_dict_has_keys(obj: Dict[str, Any],
+ expected_keys: List[str]) -> bool:
+ """Check if the obj has all the expected_keys.
+
+ Args:
+ obj (Dict[str, Any]): Object to be checked.
+ expected_keys (List[str]): Keys expected to contained in the keys of
+ the obj.
+
+ Returns:
+ bool: Whether the obj has the expected keys.
+ """
+ return set(expected_keys).issubset(set(obj.keys()))
+
+
+def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
+ """Check if target_keys is equal to result_keys.
+
+ Args:
+ result_keys (List[str]): Result keys to be checked.
+ target_keys (List[str]): Target keys to be checked.
+
+ Returns:
+ bool: Whether target_keys is equal to result_keys.
+ """
+ return set(result_keys) == set(target_keys)
+
+
+def assert_is_norm_layer(module) -> bool:
+ """Check if the module is a norm layer.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: Whether the module is a norm layer.
+ """
+ from .parrots_wrapper import _BatchNorm, _InstanceNorm
+ from torch.nn import GroupNorm, LayerNorm
+ norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
+ return isinstance(module, norm_layer_candidates)
+
+
+def assert_params_all_zeros(module) -> bool:
+ """Check if the parameters of the module is all zeros.
+
+ Args:
+ module (nn.Module): The module to be checked.
+
+ Returns:
+ bool: Whether the parameters of the module is all zeros.
+ """
+ weight_data = module.weight.data
+ is_weight_zero = weight_data.allclose(
+ weight_data.new_zeros(weight_data.size()))
+
+ if hasattr(module, 'bias') and module.bias is not None:
+ bias_data = module.bias.data
+ is_bias_zero = bias_data.allclose(
+ bias_data.new_zeros(bias_data.size()))
+ else:
+ is_bias_zero = True
+
+ return is_weight_zero and is_bias_zero
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/timer.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0435c1250ebb63e0d881d7022979a76b2dcc7298
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/timer.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from time import time
+
+
+class TimerError(Exception):
+
+ def __init__(self, message):
+ self.message = message
+ super(TimerError, self).__init__(message)
+
+
+class Timer:
+ """A flexible Timer class.
+
+ :Example:
+
+ >>> import time
+ >>> import annotator.mmpkg.mmcv as mmcv
+ >>> with mmcv.Timer():
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ 1.000
+ >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ it takes 1.0 seconds
+ >>> timer = mmcv.Timer()
+ >>> time.sleep(0.5)
+ >>> print(timer.since_start())
+ 0.500
+ >>> time.sleep(0.5)
+ >>> print(timer.since_last_check())
+ 0.500
+ >>> print(timer.since_start())
+ 1.000
+ """
+
+ def __init__(self, start=True, print_tmpl=None):
+ self._is_running = False
+ self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
+ if start:
+ self.start()
+
+ @property
+ def is_running(self):
+ """bool: indicate whether the timer is running"""
+ return self._is_running
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ print(self.print_tmpl.format(self.since_last_check()))
+ self._is_running = False
+
+ def start(self):
+ """Start the timer."""
+ if not self._is_running:
+ self._t_start = time()
+ self._is_running = True
+ self._t_last = time()
+
+ def since_start(self):
+ """Total time since the timer is started.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ self._t_last = time()
+ return self._t_last - self._t_start
+
+ def since_last_check(self):
+ """Time since the last checking.
+
+ Either :func:`since_start` or :func:`since_last_check` is a checking
+ operation.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError('timer is not running')
+ dur = time() - self._t_last
+ self._t_last = time()
+ return dur
+
+
+_g_timers = {} # global timers
+
+
+def check_time(timer_id):
+ """Add check points in a single line.
+
+ This method is suitable for running a task on a list of items. A timer will
+ be registered when the method is called for the first time.
+
+ :Example:
+
+ >>> import time
+ >>> import annotator.mmpkg.mmcv as mmcv
+ >>> for i in range(1, 6):
+ >>> # simulate a code block
+ >>> time.sleep(i)
+ >>> mmcv.check_time('task1')
+ 2.000
+ 3.000
+ 4.000
+ 5.000
+
+ Args:
+ timer_id (str): Timer identifier.
+ """
+ if timer_id not in _g_timers:
+ _g_timers[timer_id] = Timer()
+ return 0
+ else:
+ return _g_timers[timer_id].since_last_check()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/trace.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..51f6e3cab4ac7bbdf561583d7463a5f2897960e7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/trace.py
@@ -0,0 +1,23 @@
+import warnings
+
+import torch
+
+from annotator.mmpkg.mmcv.utils import digit_version
+
+
+def is_jit_tracing() -> bool:
+ if (torch.__version__ != 'parrots'
+ and digit_version(torch.__version__) >= digit_version('1.6.0')):
+ on_trace = torch.jit.is_tracing()
+ # In PyTorch 1.6, torch.jit.is_tracing has a bug.
+ # Refers to https://github.com/pytorch/pytorch/issues/42448
+ if isinstance(on_trace, bool):
+ return on_trace
+ else:
+ return torch._C._is_tracing()
+ else:
+ warnings.warn(
+ 'torch.jit.is_tracing is only supported after v1.6.0. '
+ 'Therefore is_tracing returns False automatically. Please '
+ 'set on_trace manually if you are using trace.', UserWarning)
+ return False
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import subprocess
+import warnings
+
+from packaging.version import parse
+
+
+def digit_version(version_str: str, length: int = 4):
+    """Turn a version string into a tuple of integers for comparisons.
+
+    The release part is cut or zero-padded to ``length`` levels and two more
+    items follow: pre-releases sort before the final release (alpha < beta
+    < rc), post-releases after it.
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int]: The version info in digits (integers), ``length + 2``
+            items long.
+    """
+    assert 'parrots' not in version_str
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+    # cut or zero-pad the release part to exactly ``length`` items
+    digits = list(version.release[:length])
+    digits += [0] * (length - len(digits))
+
+    if version.is_prerelease:
+        if not version.pre:
+            # flagged as pre-release although ``pre`` is None
+            return tuple(digits + [-4, 0])
+        stage_rank = {'a': -3, 'b': -2, 'rc': -1}
+        stage = version.pre[0]
+        if stage in stage_rank:
+            rank = stage_rank[stage]
+        else:
+            # unknown stage: ranked below alpha
+            rank = -4
+            warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+                          'version checking may go wrong')
+        return tuple(digits + [rank, version.pre[-1]])
+    if version.is_postrelease:
+        return tuple(digits + [1, version.post])
+    return tuple(digits + [0, 0])
+
+
+def _minimal_ext_cmd(cmd):
+    """Run ``cmd`` with a minimal, C-locale environment.
+
+    Returns the raw bytes the command wrote to stdout.
+    """
+    env = {
+        name: os.environ[name]
+        for name in ('SYSTEMROOT', 'PATH', 'HOME') if name in os.environ
+    }
+    # force the C locale (LANGUAGE is used on win32)
+    env.update(LANGUAGE='C', LANG='C', LC_ALL='C')
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)
+    stdout, _ = proc.communicate()
+    return stdout
+
+
+def get_git_hash(fallback='unknown', digits=None):
+    """Get the git hash of the current repo.
+
+    Args:
+        fallback (str, optional): The string returned when the git hash is
+            unavailable, i.e. git cannot be launched or prints nothing to
+            stdout. Defaults to 'unknown'.
+        digits (int, optional): kept digits of the hash. Defaults to None,
+            meaning all digits are kept.
+
+    Returns:
+        str: Git commit hash, or ``fallback``.
+    """
+    if digits is not None and not isinstance(digits, int):
+        raise TypeError('digits must be None or an integer')
+
+    try:
+        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+    except OSError:
+        return fallback
+    sha = out.strip().decode('ascii')
+    if not sha:
+        # git ran but produced no hash (e.g. not inside a repository)
+        return fallback
+    return sha if digits is None else sha[:digits]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/version.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/version.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+__version__ = '1.3.17'
+
+
+def parse_version_info(version_str: str, length: int = 4) -> tuple:
+    """Parse a version string into a tuple.
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
+            (1, 3, 0, 0, 0, 0), "2.0.0rc1" into (2, 0, 0, 0, 'rc', 1) and
+            "1.0.0.post2" into (1, 0, 0, 0, 'post', 2) (when length is 4).
+    """
+    from packaging.version import parse
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+    release = (list(version.release) + [0] * length)[:length]
+    if version.pre is not None:
+        release.extend(version.pre)
+    elif version.dev is not None:  # pre-release without a ``pre`` segment
+        release.extend(['dev', version.dev])
+    elif version.post is not None:
+        # ``post`` is a plain int, so list(version.post) raised TypeError
+        release.extend(['post', version.post])
+    else:
+        release.extend([0, 0])
+    return tuple(release)
+
+# first three dot-separated fields of __version__, converted to int
+version_info = tuple(int(x) for x in __version__.split('.')[:3])
+
+__all__ = ['__version__', 'version_info', 'parse_version_info']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/video/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .io import Cache, VideoReader, frames2video
+from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
+ flowwrite, quantize_flow, sparse_flow_from_bytes)
+from .processing import concat_video, convert_video, cut_video, resize_video
+
+__all__ = [
+ 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
+ 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
+ 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/video/io.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..06ae9b8ae4404ec7822fd49c01c183a0be0cbf35
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/io.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict
+
+import cv2
+from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
+
+from annotator.mmpkg.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
+ track_progress)
+
+
+class Cache:
+    """FIFO cache: put() skips stored keys, evicting the oldest when full."""
+
+    def __init__(self, capacity):
+        self._cache = OrderedDict()
+        self._capacity = int(capacity)
+        if self._capacity <= 0:  # check the converted value: int(0.5) == 0
+            raise ValueError('capacity must be a positive integer')
+
+    @property
+    def capacity(self):
+        return self._capacity
+
+    @property
+    def size(self):
+        return len(self._cache)
+
+    def put(self, key, val):
+        if key in self._cache:
+            return  # an existing entry is never overwritten
+        if len(self._cache) >= self.capacity:
+            self._cache.popitem(last=False)  # drop the oldest insertion
+        self._cache[key] = val
+
+    def get(self, key, default=None):
+        return self._cache.get(key, default)
+
+
+class VideoReader:
+ """Video class with similar usage to a list object.
+
+ This video wrapper class provides convenient APIs to access frames.
+ There exists an issue of OpenCV's VideoCapture class that jumping to a
+ certain frame may be inaccurate. It is fixed in this class by checking
+ the position after jumping each time.
+ Cache is used when decoding videos. So if the same frame is visited for
+ the second time, there is no need to decode again if it is stored in the
+ cache.
+
+ :Example:
+
+ >>> import annotator.mmpkg.mmcv as mmcv
+ >>> v = mmcv.VideoReader('sample.mp4')
+ >>> len(v) # get the total frame number with `len()`
+ 120
+ >>> for img in v: # v is iterable
+ >>> mmcv.imshow(img)
+ >>> v[5] # get the 6th frame
+ """
+
+ def __init__(self, filename, cache_capacity=10):
+ # Only local paths are checked for existence; http(s) URLs are not
+ if not filename.startswith(('https://', 'http://')):
+ check_file_exist(filename, 'Video file not found: ' + filename)
+ self._vcap = cv2.VideoCapture(filename)
+ assert cache_capacity > 0
+ self._cache = Cache(cache_capacity)
+ self._position = 0
+ # get basic info
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
+ self._fps = self._vcap.get(CAP_PROP_FPS)
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
+
+ @property
+ def vcap(self):
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
+ return self._vcap
+
+ @property
+ def opened(self):
+ """bool: Indicate whether the video is opened."""
+ return self._vcap.isOpened()
+
+ @property
+ def width(self):
+ """int: Width of video frames."""
+ return self._width
+
+ @property
+ def height(self):
+ """int: Height of video frames."""
+ return self._height
+
+ @property
+ def resolution(self):
+ """tuple: Video resolution (width, height)."""
+ return (self._width, self._height)
+
+ @property
+ def fps(self):
+ """float: FPS of the video."""
+ return self._fps
+
+ @property
+ def frame_cnt(self):
+ """int: Total frames of the video."""
+ return self._frame_cnt
+
+ @property
+ def fourcc(self):
+ """"Four character code" of the video, as returned by ``vcap.get``."""
+ return self._fourcc
+
+ @property
+ def position(self):
+ """int: Current cursor position, indicating frame decoded."""
+ return self._position
+
+ def _get_real_position(self):
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
+
+ def _set_real_position(self, frame_id):
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
+ pos = self._get_real_position()
+ for _ in range(frame_id - pos):
+ self._vcap.read()
+ self._position = frame_id
+
+ def read(self):
+ """Read the next frame.
+
+ If the next frame has been decoded before and is in the cache, then
+ return it directly, otherwise decode, cache and return it.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ # NOTE: Cache defines neither __bool__ nor __len__, so this is always true
+ if self._cache:
+ img = self._cache.get(self._position)
+ if img is not None:
+ ret = True
+ else:
+ if self._position != self._get_real_position():
+ self._set_real_position(self._position)
+ ret, img = self._vcap.read()
+ if ret:
+ self._cache.put(self._position, img)
+ else:
+ ret, img = self._vcap.read()
+ if ret:
+ self._position += 1
+ return img
+
+ def get_frame(self, frame_id):
+ """Get frame by index.
+
+ Args:
+ frame_id (int): Index of the expected frame, 0-based.
+
+ Returns:
+ ndarray or None: Return the frame if successful, otherwise None.
+ """
+ if frame_id < 0 or frame_id >= self._frame_cnt:
+ raise IndexError(
+ f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
+ if frame_id == self._position:
+ return self.read()
+ if self._cache:
+ img = self._cache.get(frame_id)
+ if img is not None:
+ self._position = frame_id + 1
+ return img
+ self._set_real_position(frame_id)
+ ret, img = self._vcap.read()
+ if ret:
+ if self._cache:
+ self._cache.put(self._position, img)
+ self._position += 1
+ return img
+
+ def current_frame(self):
+ """Get the current frame (frame that is just visited).
+
+ Returns:
+ ndarray or None: If the video is fresh, return None, otherwise
+ return the cache entry of the frame before the cursor.
+ """
+ if self._position == 0:
+ return None
+ return self._cache.get(self._position - 1)
+
+ def cvt2frames(self,
+ frame_dir,
+ file_start=0,
+ filename_tmpl='{:06d}.jpg',
+ start=0,
+ max_num=0,
+ show_progress=True):
+ """Convert a video to frame images.
+
+ Args:
+ frame_dir (str): Output directory to store all the frame images.
+ file_start (int): Filenames will start from the specified number.
+ filename_tmpl (str): Filename template with the index as the
+ placeholder.
+ start (int): The starting frame index; the cursor only moves if > 0.
+ max_num (int): Maximum number of frames to be written, 0 for all.
+ show_progress (bool): Whether to show a progress bar.
+ """
+ mkdir_or_exist(frame_dir)
+ if max_num == 0:
+ task_num = self.frame_cnt - start
+ else:
+ task_num = min(self.frame_cnt - start, max_num)
+ if task_num <= 0:
+ raise ValueError('start must be less than total frame number')
+ if start > 0:
+ self._set_real_position(start)
+
+ def write_frame(file_idx):
+ img = self.read()
+ if img is None:
+ return
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+ cv2.imwrite(filename, img)
+
+ if show_progress:
+ track_progress(write_frame, range(file_start,
+ file_start + task_num))
+ else:
+ for i in range(task_num):
+ write_frame(file_start + i)
+
+ def __len__(self):
+ return self.frame_cnt
+
+ def __getitem__(self, index):
+ if isinstance(index, slice):
+ return [
+ self.get_frame(i)
+ for i in range(*index.indices(self.frame_cnt))
+ ]
+ # support negative indexing
+ if index < 0:
+ index += self.frame_cnt
+ if index < 0:
+ raise IndexError('index out of range')
+ return self.get_frame(index)
+
+ def __iter__(self):
+ self._set_real_position(0)
+ return self
+
+ def __next__(self):
+ img = self.read()
+ if img is not None:
+ return img
+ else:
+ raise StopIteration
+
+ next = __next__
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._vcap.release()
+
+
+def frames2video(frame_dir,
+                 video_file,
+                 fps=30,
+                 fourcc='XVID',
+                 filename_tmpl='{:06d}.jpg',
+                 start=0,
+                 end=0,
+                 show_progress=True):
+    """Read the frame images from a directory and join them as a video.
+
+    Args:
+        frame_dir (str): The directory containing video frames.
+        video_file (str): Output filename.
+        fps (float): FPS of the output video.
+        fourcc (str): Fourcc of the output video, this should be compatible
+            with the output file type.
+        filename_tmpl (str): Filename template with the index as the variable.
+        start (int): Starting frame index.
+        end (int): Ending frame index (exclusive). 0 means: use the number
+            of entries that ``scandir(frame_dir, ext)`` yields.
+        show_progress (bool): Whether to show a progress bar.
+    """
+    if end == 0:
+        ext = filename_tmpl.split('.')[-1]
+        end = len(list(scandir(frame_dir, ext)))
+    first_file = osp.join(frame_dir, filename_tmpl.format(start))
+    check_file_exist(first_file, 'The start frame not found: ' + first_file)
+    height, width = cv2.imread(first_file).shape[:2]
+    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
+                              (width, height))
+
+    def write_frame(file_idx):
+        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+        vwriter.write(cv2.imread(filename))
+
+    try:  # release the writer even if reading/writing a frame fails
+        if show_progress:
+            track_progress(write_frame, range(start, end))
+        else:
+            for i in range(start, end):
+                write_frame(i)
+    finally:
+        vwriter.release()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/video/optflow.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bd78970dce8faf30bce0d5f2ec278b994fdd623
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/optflow.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import cv2
+import numpy as np
+
+from annotator.mmpkg.mmcv.arraymisc import dequantize, quantize
+from annotator.mmpkg.mmcv.image import imread, imwrite
+from annotator.mmpkg.mmcv.utils import is_str
+
+
+def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
+    """Load an optical flow map, or validate one that is already an array.
+
+    Args:
+        flow_or_path (ndarray or str): A flow map or filepath.
+        quantize (bool): whether to read quantized pair, if set to True,
+            remaining args will be passed to :func:`dequantize_flow`.
+        concat_axis (int): The axis that dx and dy are concatenated,
+            can be either 0 or 1. Ignored if quantize is False.
+
+    Returns:
+        ndarray: Optical flow represented as a (h, w, 2) numpy array
+    """
+    if isinstance(flow_or_path, np.ndarray):
+        # arrays are only validated and handed back unchanged
+        if flow_or_path.ndim == 3 and flow_or_path.shape[-1] == 2:
+            return flow_or_path
+        raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
+    if not is_str(flow_or_path):
+        raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
+                        f'not {type(flow_or_path)}')
+
+    if quantize:
+        # quantized pair: dx and dy are stored next to each other in one image
+        assert concat_axis in [0, 1]
+        cat_flow = imread(flow_or_path, flag='unchanged')
+        if cat_flow.ndim != 2:
+            raise IOError(
+                f'{flow_or_path} is not a valid quantized flow file, '
+                f'its dimension is {cat_flow.ndim}.')
+        assert cat_flow.shape[concat_axis] % 2 == 0
+        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+        return dequantize_flow(dx, dy, *args, **kwargs).astype(np.float32)
+
+    with open(flow_or_path, 'rb') as f:
+        try:
+            header = f.read(4).decode('utf-8')
+        except Exception:
+            raise IOError(f'Invalid flow file: {flow_or_path}')
+        if header != 'PIEH':
+            raise IOError(f'Invalid flow file: {flow_or_path}, '
+                          'header does not contain PIEH')
+        # layout after the magic: int32 width, int32 height, float32 data
+        w = np.fromfile(f, np.int32, 1).squeeze()
+        h = np.fromfile(f, np.int32, 1).squeeze()
+        flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+    return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+    """Write optical flow to file.
+
+    Without quantization the flow is dumped losslessly: the bytes ``PIEH``,
+    then width and height as int32, then the flow values as float32. With
+    quantization dx and dy are quantized, concatenated along ``concat_axis``
+    and written as a single image through ``imwrite``.
+
+    Args:
+        flow (ndarray): (h, w, 2) array of optical flow.
+        filename (str): Output filepath.
+        quantize (bool): Whether to quantize the flow before saving. If set
+            to True, remaining args will be passed to :func:`quantize_flow`.
+        concat_axis (int): The axis that dx and dy are concatenated,
+            can be either 0 or 1. Ignored if quantize is False.
+    """
+    if quantize:
+        assert concat_axis in [0, 1]
+        dx, dy = quantize_flow(flow, *args, **kwargs)
+        dxdy = np.concatenate((dx, dy), axis=concat_axis)
+        imwrite(dxdy, filename)
+        return
+    with open(filename, 'wb') as f:
+        f.write('PIEH'.encode('utf-8'))
+        # the header stores the width first, then the height
+        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+        flow.astype(np.float32).tofile(f)
+        f.flush()
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+    """Quantize the two flow components to 255 uint8 levels.
+
+    After this step, the size of flow will be much smaller, and can be
+    dumped as jpeg images.
+
+    Args:
+        flow (ndarray): (h, w, 2) array of optical flow.
+        max_val (float): Maximum value of flow, values beyond
+            [-max_val, max_val] will be truncated.
+        norm (bool): Whether to divide flow values by image width/height.
+
+    Returns:
+        tuple[ndarray]: Quantized dx and dy.
+    """
+    h, w, _ = flow.shape
+    dx, dy = flow[..., 0], flow[..., 1]
+    if norm:
+        # plain division builds new arrays, so ``flow`` is left untouched
+        dx, dy = dx / w, dy / h
+
+    def _to_uint8(component):
+        # 255 levels instead of 256 to make sure 0 is 0 after dequantization
+        return quantize(component, -max_val, max_val, 255, np.uint8)
+
+    return _to_uint8(dx), _to_uint8(dy)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+    """Turn a quantized dx/dy pair back into one flow array.
+
+    Args:
+        dx (ndarray): Quantized dx.
+        dy (ndarray): Quantized dy.
+        max_val (float): Maximum value used when quantizing.
+        denorm (bool): Whether to multiply flow values with width/height.
+
+    Returns:
+        ndarray: Dequantized flow.
+    """
+    assert dx.shape == dy.shape
+    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+    dx = dequantize(dx, -max_val, max_val, 255)
+    dy = dequantize(dy, -max_val, max_val, 255)
+    if denorm:
+        rows, cols = dx.shape[:2]
+        dx *= cols
+        dy *= rows
+    return np.dstack((dx, dy))
+
+
+def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
+    """Use flow to warp img.
+
+    Args:
+        img (ndarray, float or uint8): Image to be warped.
+        flow (ndarray, float): Optical Flow.
+        filling_value (int): The missing pixels will be set with filling_value.
+        interpolate_mode (str): bilinear -> Bilinear Interpolation;
+            nearest -> Nearest Neighbor.
+
+    Returns:
+        ndarray: Warped image, shaped (flow h, flow w, img channels).
+    """
+    warnings.warn('This function is just for prototyping and cannot '
+                  'guarantee the computational efficiency.')
+    assert flow.ndim == 3, 'Flow must be in 3D arrays.'
+    height, width = flow.shape[0], flow.shape[1]
+    channels = img.shape[2]
+
+    output = np.ones(
+        (height, width, channels), dtype=img.dtype) * filling_value
+
+    # source position of every output pixel: channel 1 of the flow is added
+    # to the row index, channel 0 to the column index
+    row_idx, col_idx = np.indices((height, width))
+    src_row = row_idx + flow[:, :, 1]
+    src_col = col_idx + flow[:, :, 0]
+    # only sources with floor(row) in [0, height - 2] and floor(col) in
+    # [0, width - 2] are sampled; all other pixels keep ``filling_value``
+    floor_row = np.floor(src_row).astype(int)
+    floor_col = np.floor(src_col).astype(int)
+    valid = ((floor_row >= 0) & (floor_row < height - 1) &
+             (floor_col >= 0) & (floor_col < width - 1))
+
+    if interpolate_mode == 'nearest':
+        output[valid, :] = img[src_row[valid].round().astype(int),
+                               src_col[valid].round().astype(int), :]
+    elif interpolate_mode == 'bilinear':
+        # dirty workaround for integer positions
+        eps_ = 1e-6
+        r = src_row[valid] + eps_
+        c = src_col[valid] + eps_
+        r_lo, r_hi = np.floor(r), np.ceil(r)
+        c_lo, c_hi = np.floor(c), np.ceil(c)
+        top, bottom = r_lo.astype(int), r_hi.astype(int)
+        left, right = c_lo.astype(int), c_hi.astype(int)
+        # four neighbouring pixels, each weighted along rows and columns by
+        # ``ceil - x`` (lower neighbour) or ``x - floor`` (upper neighbour)
+        w_top, w_bottom = (r_hi - r)[:, None], (r - r_lo)[:, None]
+        w_left, w_right = (c_hi - c)[:, None], (c - c_lo)[:, None]
+        output[valid, :] = (img[top, left, :] * w_top * w_left +
+                            img[bottom, left, :] * w_bottom * w_left +
+                            img[top, right, :] * w_top * w_right +
+                            img[bottom, right, :] * w_bottom * w_right)
+    else:
+        raise NotImplementedError(
+            'We only support interpolation modes of nearest and bilinear, '
+            f'but got {interpolate_mode}.')
+    return output.astype(img.dtype)
+
+
+def flow_from_bytes(content):
+    """Read dense optical flow from bytes.
+
+    The buffer must start with the 4 bytes ``PIEH``, followed by the width
+    and the height as int32 and then ``width * height * 2`` float32 values.
+
+    .. note::
+        This load optical flow function works for FlyingChairs, FlyingThings3D,
+        Sintel, FlyingChairsOcc datasets, but cannot load the data from
+        ChairsSDHom.
+
+    Args:
+        content (bytes): Optical flow bytes got from files or other streams.
+
+    Returns:
+        ndarray: Loaded optical flow with the shape (H, W, 2).
+    """
+
+    # bytes 0-3: magic string
+    if content[:4].decode('utf-8') != 'PIEH':
+        raise Exception('Flow file header does not contain PIEH')
+    # bytes 4-7: width, bytes 8-11: height
+    width = np.frombuffer(content, np.int32, 1, offset=4).squeeze()
+    height = np.frombuffer(content, np.int32, 1, offset=8).squeeze()
+    # everything from byte 12 on is flow data
+    values = np.frombuffer(
+        content, np.float32, width * height * 2, offset=12)
+    return values.reshape((height, width, 2))
+
+
+def sparse_flow_from_bytes(content):
+    """Read the optical flow in KITTI datasets from bytes.
+
+    This function is modified from the code RAFT uses to load the KITTI
+    datasets.
+
+    Args:
+        content (bytes): Optical flow bytes got from files or other streams.
+
+    Returns:
+        Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
+        and flow valid mask with the shape (H, W).
+    """
+
+    encoded = np.frombuffer(content, np.uint8)
+    decoded = cv2.imdecode(encoded, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+    # reverse the channel order, then split: 2 flow channels + 1 valid mask
+    channels = decoded[:, :, ::-1].astype(np.float32)
+    flow = (channels[:, :, :2] - 2**15) / 64.0
+    valid = channels[:, :, 2]
+    return flow, valid
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/video/processing.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b93a59215d56b6e5ba05f48bca3527772f0c744
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/video/processing.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import subprocess
+import tempfile
+
+from annotator.mmpkg.mmcv.utils import requires_executable
+
+
+@requires_executable('ffmpeg')
+def convert_video(in_file,
+                  out_file,
+                  print_cmd=False,
+                  pre_options='',
+                  **kwargs):
+    """Convert a video with ffmpeg.
+
+    This provides a general api to ffmpeg, the executed command is::
+
+        ffmpeg -y {pre_options} -i {in_file} {options} {out_file}
+
+    Options(kwargs) are mapped to ffmpeg commands with the following rules:
+
+    - key=val: "-key val"
+    - key=True: "-key"
+    - key=False: ""
+
+    Args:
+        in_file (str): Input video filename.
+        out_file (str): Output video filename.
+        pre_options (str): Options appears before "-i {in_file}".
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    log_levels = ('quiet', 'panic', 'fatal', 'error', 'warning', 'info',
+                  'verbose', 'debug', 'trace')
+    options = []
+    for key, val in kwargs.items():
+        if isinstance(val, bool):
+            if val:
+                options.append(f'-{key}')
+            continue
+        if key == 'log_level':
+            assert val in log_levels
+            key = 'loglevel'  # ffmpeg spells the option without underscore
+        options.append(f'-{key} {val}')
+    # NOTE(review): runs through the shell; file names are not quoted.
+    cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
+          f'{out_file}'
+    if print_cmd:
+        print(cmd)
+    subprocess.call(cmd, shell=True)
+
+
+@requires_executable('ffmpeg')
+def resize_video(in_file,
+                 out_file,
+                 size=None,
+                 ratio=None,
+                 keep_ar=False,
+                 log_level='info',
+                 print_cmd=False):
+    """Resize a video to a target size or by a ratio (exactly one of both).
+
+    Args:
+        in_file (str): Input video filename.
+        out_file (str): Output video filename.
+        size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1).
+        ratio (tuple or float): Expected resize ratio, (2, 0.5) means
+            (w*2, h*0.5).
+        keep_ar (bool): Whether to keep original aspect ratio.
+        log_level (str): Logging level of ffmpeg.
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    if size is None and ratio is None:
+        raise ValueError('expected size or ratio must be specified')
+    if size is not None and ratio is not None:
+        raise ValueError('size and ratio cannot be specified at the same time')
+    # build the value that is handed to convert_video as the ``vf`` option
+    if size and keep_ar:
+        scale = f'scale=w={size[0]}:h={size[1]}:' \
+                'force_original_aspect_ratio=decrease'
+    elif size:
+        scale = f'scale={size[0]}:{size[1]}'
+    else:
+        if not isinstance(ratio, tuple):
+            ratio = (ratio, ratio)
+        scale = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
+    convert_video(
+        in_file, out_file, print_cmd, log_level=log_level, vf=scale)
+
+
+@requires_executable('ffmpeg')
+def cut_video(in_file,
+              out_file,
+              start=None,
+              end=None,
+              vcodec=None,
+              acodec=None,
+              log_level='info',
+              print_cmd=False):
+    """Cut a clip from a video.
+
+    Args:
+        in_file (str): Input video filename.
+        out_file (str): Output video filename.
+        start (None or float): Start time (in seconds).
+        end (None or float): End time (in seconds).
+        vcodec (None or str): Output video codec, None for unchanged (copy).
+        acodec (None or str): Output audio codec, None for unchanged (copy).
+        log_level (str): Logging level of ffmpeg.
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    options = {
+        'log_level': log_level,
+        'vcodec': 'copy' if vcodec is None else vcodec,
+        'acodec': 'copy' if acodec is None else acodec,
+    }
+    if start:
+        options['ss'] = start
+    else:
+        start = 0
+    if end:
+        options['t'] = end - start
+    convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def concat_video(video_list,
+                 out_file,
+                 vcodec=None,
+                 acodec=None,
+                 log_level='info',
+                 print_cmd=False):
+    """Concatenate multiple videos into a single one.
+
+    Args:
+        video_list (list): A list of video filenames
+        out_file (str): Output video filename
+        vcodec (None or str): Output video codec, None for unchanged
+        acodec (None or str): Output audio codec, None for unchanged
+        log_level (str): Logging level of ffmpeg.
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    options = {
+        'log_level': log_level,
+        'vcodec': 'copy' if vcodec is None else vcodec,
+        'acodec': 'copy' if acodec is None else acodec,
+    }
+    # ffmpeg's concat demuxer (-f concat) reads its inputs from a list file;
+    # fdopen reuses the descriptor mkstemp returned instead of reopening
+    tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
+    try:
+        with os.fdopen(tmp_filehandler, 'w') as f:
+            for filename in video_list:
+                f.write(f'file {osp.abspath(filename)}\n')
+        convert_video(
+            tmp_filename, out_file, print_cmd,
+            pre_options='-f concat -safe 0', **options)
+    finally:  # remove the list file even when ffmpeg could not be run
+        os.remove(tmp_filename)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..835df136bdcf69348281d22914d41aa84cdf92b1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .color import Color, color_val
+from .image import imshow, imshow_bboxes, imshow_det_bboxes
+from .optflow import flow2rgb, flowshow, make_color_wheel
+
+__all__ = [
+ 'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
+ 'flowshow', 'flow2rgb', 'make_color_wheel'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/color.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/color.py
new file mode 100644
index 0000000000000000000000000000000000000000..48379a283e48570f226426510270de8e15323c8d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/color.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+import numpy as np
+
+from annotator.mmpkg.mmcv.utils import is_str
+
+
+class Color(Enum):
+    """Named colour triples in BGR channel order.
+
+    Members: red, green, blue, cyan, yellow, magenta, white and black.
+    """
+    red = (0, 0, 255)
+    green = (0, 255, 0)
+    blue = (255, 0, 0)
+    cyan = (255, 255, 0)
+    yellow = (0, 255, 255)
+    magenta = (255, 0, 255)
+    white = (255, 255, 255)
+    black = (0, 0, 0)
+
+
+def color_val(color):
+    """Convert various input to color tuples.
+
+    Args:
+        color (:obj:`Color`/str/tuple/int/ndarray): Color inputs. A str is
+            looked up by name in :obj:`Color`; an int is used for all three
+            channels.
+
+    Returns:
+        tuple[int]: A tuple of 3 integers indicating BGR channels.
+    """
+    if is_str(color):
+        return Color[color].value
+    if isinstance(color, Color):
+        return color.value
+    if isinstance(color, tuple):
+        assert len(color) == 3
+        assert all(0 <= channel <= 255 for channel in color)
+        return color
+    if isinstance(color, int):
+        assert 0 <= color <= 255
+        return color, color, color
+    if isinstance(color, np.ndarray):
+        assert color.ndim == 1 and color.size == 3
+        assert np.all((color >= 0) & (color <= 255))
+        # tolist() yields plain Python ints rather than numpy scalars
+        return tuple(color.astype(np.uint8).tolist())
+    raise TypeError(f'Invalid type for color: {type(color)}')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/image.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..378de2104f6554389fcb2e6a3904283345fd74b0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/image.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from annotator.mmpkg.mmcv.image import imread, imwrite
+from .color import color_val
+
+
+def imshow(img, win_name='', wait_time=0):
+    """Show an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+    """
+    cv2.imshow(win_name, imread(img))
+    if wait_time != 0:
+        cv2.waitKey(wait_time)
+        return
+    # poll instead of blocking so that closing the window cannot hang us
+    while True:
+        key = cv2.waitKey(1)
+        closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
+        # stop once the user closed the window or pressed some key
+        if closed or key != -1:
+            break
+
+
+def imshow_bboxes(img,
+                  bboxes,
+                  colors='green',
+                  top_k=-1,
+                  thickness=1,
+                  show=True,
+                  win_name='',
+                  wait_time=0,
+                  out_file=None):
+    """Draw groups of bboxes on an image, one colour per group.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        bboxes (list or ndarray): A list of ndarray of shape (k, 4).
+        colors (list[str or tuple or Color]): A list of colors.
+        top_k (int): Plot the first k bboxes only if set positive.
+        thickness (int): Thickness of lines.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+        out_file (str, optional): The filename to write the image.
+
+    Returns:
+        ndarray: The image with bboxes drawn on it.
+    """
+    img = imread(img)
+    img = np.ascontiguousarray(img)
+
+    if isinstance(bboxes, np.ndarray):
+        bboxes = [bboxes]
+    if not isinstance(colors, list):
+        # a single colour is shared by every group of bboxes
+        colors = [colors] * len(bboxes)
+    colors = [color_val(c) for c in colors]
+    assert len(bboxes) == len(colors)
+
+    for group, color in zip(bboxes, colors):
+        group = group.astype(np.int32)
+        # a non-positive top_k means: draw every bbox of the group
+        num_shown = group.shape[0]
+        if top_k > 0:
+            num_shown = min(top_k, num_shown)
+        for box in group[:num_shown]:
+            cv2.rectangle(
+                img, (box[0], box[1]), (box[2], box[3]),
+                color, thickness=thickness)
+
+    if show:
+        imshow(img, win_name, wait_time)
+    if out_file is not None:
+        imwrite(img, out_file)
+    return img
+
+
+def imshow_det_bboxes(img,
+                      bboxes,
+                      labels,
+                      class_names=None,
+                      score_thr=0,
+                      bbox_color='green',
+                      text_color='green',
+                      thickness=1,
+                      font_scale=0.5,
+                      show=True,
+                      win_name='',
+                      wait_time=0,
+                      out_file=None):
+    """Draw bboxes and class labels (with scores) on an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+            (n, 5).
+        labels (ndarray): Labels of bboxes.
+        class_names (list[str]): Names of each classes.
+        score_thr (float): Minimum score of bboxes to be shown.
+        bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+        text_color (str or tuple or :obj:`Color`): Color of texts.
+        thickness (int): Thickness of lines.
+        font_scale (float): Font scales of texts.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+        out_file (str or None): The filename to write the image.
+
+    Returns:
+        ndarray: The image with bboxes drawn on it.
+    """
+    assert bboxes.ndim == 2
+    assert labels.ndim == 1
+    assert bboxes.shape[0] == labels.shape[0]
+    assert bboxes.shape[1] in (4, 5)
+    img = imread(img)
+    img = np.ascontiguousarray(img)
+
+    if score_thr > 0:
+        assert bboxes.shape[1] == 5
+        # keep only the detections scoring above the threshold
+        keep = bboxes[:, -1] > score_thr
+        bboxes, labels = bboxes[keep, :], labels[keep]
+
+    bbox_color = color_val(bbox_color)
+    text_color = color_val(text_color)
+    has_score = bboxes.shape[1] > 4
+
+    for bbox, label in zip(bboxes, labels):
+        x1, y1, x2, y2 = bbox.astype(np.int32)[:4]
+        cv2.rectangle(
+            img, (x1, y1), (x2, y2), bbox_color, thickness=thickness)
+        if class_names is not None:
+            label_text = class_names[label]
+        else:
+            label_text = f'cls {label}'
+        if has_score:
+            label_text += f'|{bbox[-1]:.02f}'
+        cv2.putText(img, label_text, (x1, y1 - 2),
+                    cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+
+    if show:
+        imshow(img, win_name, wait_time)
+    if out_file is not None:
+        imwrite(img, out_file)
+    return img
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c3ce980f9f6c74c85fe714aca1623a08ae7a8d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+
+import numpy as np
+
+from annotator.mmpkg.mmcv.image import rgb2bgr
+from annotator.mmpkg.mmcv.video import flowread
+from .image import imshow
+
+
+def flowshow(flow, win_name='', wait_time=0):
+ """Show optical flow.
+
+ Args:
+ flow (ndarray or str): The optical flow to be displayed.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ """
+ flow = flowread(flow)
+ flow_img = flow2rgb(flow)
+ imshow(rgb2bgr(flow_img), win_name, wait_time)
+
+
+def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
+    """Convert flow map to RGB image.
+
+    Args:
+        flow (ndarray): Optical flow, a 3-D array whose last axis is (dx, dy).
+        color_wheel (ndarray or None): Color wheel used to map flow field to
+            RGB colorspace. Default color wheel will be used if not specified.
+        unknown_thr (float): Pixels whose |dx| or |dy| exceeds this threshold
+            (or is NaN) are treated as unknown and set to 0 in the output.
+
+    Returns:
+        ndarray: RGB image that can be visualized.
+    """
+    assert flow.ndim == 3 and flow.shape[-1] == 2
+    if color_wheel is None:
+        color_wheel = make_color_wheel()
+    assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
+    num_bins = color_wheel.shape[0]
+
+    dx = flow[:, :, 0].copy()  # copies: the caller's array is left untouched
+    dy = flow[:, :, 1].copy()
+
+    ignore_inds = (
+        np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
+        (np.abs(dy) > unknown_thr))
+    dx[ignore_inds] = 0
+    dy[ignore_inds] = 0
+
+    rad = np.sqrt(dx**2 + dy**2)
+    if np.any(rad > np.finfo(float).eps):
+        max_rad = np.max(rad)  # rescale so the largest magnitude becomes 1
+        dx /= max_rad
+        dy /= max_rad
+
+    rad = np.sqrt(dx**2 + dy**2)
+    angle = np.arctan2(-dy, -dx) / np.pi  # in [-1, 1]
+
+    bin_real = (angle + 1) / 2 * (num_bins - 1)  # fractional wheel position
+    bin_left = np.floor(bin_real).astype(int)
+    bin_right = (bin_left + 1) % num_bins
+    w = (bin_real - bin_left.astype(np.float32))[..., None]
+    flow_img = (1 -
+                w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
+    small_ind = rad <= 1
+    flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
+    flow_img[np.logical_not(small_ind)] *= 0.75
+
+    flow_img[ignore_inds, :] = 0
+
+    return flow_img
+
+
+def make_color_wheel(bins=None):
+ """Build a color wheel.
+
+ Args:
+ bins(list or tuple, optional): Specify the number of bins for each
+ color range, corresponding to six ranges: red -> yellow,
+ yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
+ magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
+ (see Middlebury).
+
+ Returns:
+ ndarray: Color wheel of shape (total_bins, 3).
+ """
+ if bins is None:
+ bins = [15, 6, 4, 11, 13, 6]
+ assert len(bins) == 6
+
+ RY, YG, GC, CB, BM, MR = tuple(bins)
+
+ ry = [1, np.arange(RY) / RY, 0]
+ yg = [1 - np.arange(YG) / YG, 1, 0]
+ gc = [0, 1, np.arange(GC) / GC]
+ cb = [0, 1 - np.arange(CB) / CB, 1]
+ bm = [np.arange(BM) / BM, 0, 1]
+ mr = [1, 0, 1 - np.arange(MR) / MR]
+
+ num_bins = RY + YG + GC + CB + BM + MR
+
+ color_wheel = np.zeros((3, num_bins), dtype=np.float32)
+
+ col = 0
+ for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
+ for j in range(3):
+ color_wheel[j, col:col + bins[i]] = color[j]
+ col += bins[i]
+
+ return color_wheel.T
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/__init__.py
@@ -0,0 +1,9 @@
+from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_segmentor
+
+__all__ = [
+ 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+ 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+ 'show_result_pyplot'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/inference.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..515e459ff6e66e955624fedaf32d2076be750563
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/inference.py
@@ -0,0 +1,138 @@
+import matplotlib.pyplot as plt
+import annotator.mmpkg.mmcv as mmcv
+import torch
+from annotator.mmpkg.mmcv.parallel import collate, scatter
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+
+from annotator.mmpkg.mmseg.datasets.pipelines import Compose
+from annotator.mmpkg.mmseg.models import build_segmentor
+from modules import devices
+
+
+def init_segmentor(config, checkpoint=None, device=devices.get_device_for("controlnet")):
+    """Initialize a segmentor from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or Config object.
+        checkpoint (str, optional): Checkpoint path; no weights loaded if None.
+        device (optional): Target passed to ``model.to``. The default is
+            ``devices.get_device_for("controlnet")``, evaluated once at def time.
+    Returns:
+        nn.Module: The constructed segmentor.
+    Raises:
+        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    config.model.pretrained = None
+    config.model.train_cfg = None
+    model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+        model.CLASSES = checkpoint['meta']['CLASSES']
+        model.PALETTE = checkpoint['meta']['PALETTE']
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+class LoadImage:
+ """A simple pipeline to load image."""
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+ results (dict): A result dict contains the file name
+ of the image to be read.
+
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+
+ if isinstance(results['img'], str):
+ results['filename'] = results['img']
+ results['ori_filename'] = results['img']
+ else:
+ results['filename'] = None
+ results['ori_filename'] = None
+ img = mmcv.imread(results['img'])
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ return results
+
+
+def inference_segmentor(model, img):
+    """Inference an image with the segmentor.
+
+    Args:
+        model (nn.Module): The loaded segmentor; ``model.cfg`` must be set.
+        img (str or ndarray): Image file or loaded image. The processed
+            input is placed on the device holding the model parameters.
+
+    Returns:
+        (list[Tensor]): The segmentation result.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img)
+    data = test_pipeline(data)
+    data = collate([data], samples_per_gpu=1)
+    if device.type == 'cuda':
+        # scatter to specified GPU
+        data = scatter(data, [device])[0]
+    else:
+        data['img'][0] = data['img'][0].to(device)  # follow the model's device
+        data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+    # forward the model
+    with torch.no_grad():
+        result = model(return_loss=False, rescale=True, **data)
+    return result
+
+
+def show_result_pyplot(model,
+                       img,
+                       result,
+                       palette=None,
+                       fig_size=(15, 10),
+                       opacity=0.5,
+                       title='',
+                       block=True):
+    """Paint the segmentation results on the image and return the picture.
+
+    Args:
+        model (nn.Module): The loaded segmentor.
+        img (str or np.ndarray): Image filename or loaded image.
+        result (list): The segmentation result.
+        palette (list[list[int]] | None): The palette of segmentation
+            map. If None is given, random palette will be generated.
+            Default: None
+        fig_size (tuple): Unused; the pyplot calls below are commented out.
+        opacity (float): Opacity of painted segmentation map.
+            Default 0.5. Must be in (0, 1] range.
+        title (str): Unused; the pyplot calls below are commented out.
+        block (bool): Unused; the pyplot calls below are commented out.
+
+    Returns:
+        Painted image from ``model.show_result`` after ``mmcv.bgr2rgb``.
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+    img = model.show_result(
+        img, result, palette=palette, show=False, opacity=opacity)
+    # plt.figure(figsize=fig_size)
+    # plt.imshow(mmcv.bgr2rgb(img))
+    # plt.title(title)
+    # plt.tight_layout()
+    # plt.show(block=block)
+    return mmcv.bgr2rgb(img)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/test.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9954e6a3709afdbf6a2027b213afcad644c47d7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/test.py
@@ -0,0 +1,238 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from annotator.mmpkg.mmcv.image import tensor2imgs
+from annotator.mmpkg.mmcv.runner import get_dist_info
+
+
+def np2tmp(array, temp_file_name=None):
+ """Save ndarray to local numpy file.
+
+ Args:
+ array (ndarray): Ndarray to save.
+ temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
+ function will generate a file name with tempfile.NamedTemporaryFile
+ to save ndarray. Default: None.
+
+ Returns:
+ str: The numpy file name.
+ """
+
+ if temp_file_name is None:
+ temp_file_name = tempfile.NamedTemporaryFile(
+ suffix='.npy', delete=False).name
+ np.save(temp_file_name, array)
+ return temp_file_name
+
+
+def single_gpu_test(model,
+ data_loader,
+ show=False,
+ out_dir=None,
+ efficient_test=False,
+ opacity=0.5):
+ """Test with single GPU.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ show (bool): Whether show results during inference. Default: False.
+ out_dir (str, optional): If specified, the results will be dumped into
+ the directory to save output results.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ list: The prediction results.
+ """
+
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, **data)
+
+ if show or out_dir:
+ img_tensor = data['img'][0]
+ img_metas = data['img_metas'][0].data[0]
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+ assert len(imgs) == len(img_metas)
+
+ for img, img_meta in zip(imgs, img_metas):
+ h, w, _ = img_meta['img_shape']
+ img_show = img[:h, :w, :]
+
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+ if out_dir:
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
+ else:
+ out_file = None
+
+ model.module.show_result(
+ img_show,
+ result,
+ palette=dataset.PALETTE,
+ show=show,
+ out_file=out_file,
+ opacity=opacity)
+
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+
+ batch_size = len(result)
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model,
+ data_loader,
+ tmpdir=None,
+ gpu_collect=False,
+ efficient_test=False):
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+ it encodes results to gpu tensors and use gpu communication for results
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+ and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (utils.data.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+ efficient_test (bool): Whether save the results as local numpy files to
+ save CPU memory during evaluation. Default: False.
+
+ Returns:
+ list: The prediction results.
+ """
+
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+
+ if isinstance(result, list):
+ if efficient_test:
+ result = [np2tmp(_) for _ in result]
+ results.extend(result)
+ else:
+ if efficient_test:
+ result = np2tmp(result)
+ results.append(result)
+
+ if rank == 0:
+ batch_size = data['img'][0].size(0)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ """Collect results with CPU."""
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ tmpdir = tempfile.mkdtemp()
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+ part_list.append(mmcv.load(part_file))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ """Collect results with GPU."""
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_list.append(
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/train.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0a87d65c72e4581c96b41aebf879905510c9d22
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/apis/train.py
@@ -0,0 +1,116 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from annotator.mmpkg.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from annotator.mmpkg.mmcv.runner import build_optimizer, build_runner
+
+from annotator.mmpkg.mmseg.core import DistEvalHook, EvalHook
+from annotator.mmpkg.mmseg.datasets import build_dataloader, build_dataset
+from annotator.mmpkg.mmseg.utils import get_root_logger
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def train_segmentor(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ """Launch segmentor training."""
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ # cfg.gpus will be ignored if distributed
+ len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed,
+ drop_last=True) for ds in dataset
+ ]
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDataParallel(
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+ # build runner
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ if cfg.get('runner') is None:
+ cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ batch_processor=None,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+ cfg.checkpoint_config, cfg.log_config,
+ cfg.get('momentum_config', None))
+
+ # an ugly walkaround to make the .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # register eval hooks
+ if validate:
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+ val_dataloader = build_dataloader(
+ val_dataset,
+ samples_per_gpu=1,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation import * # noqa: F401, F403
+from .seg import * # noqa: F401, F403
+from .utils import * # noqa: F401, F403
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py
@@ -0,0 +1,8 @@
+from .class_names import get_classes, get_palette
+from .eval_hooks import DistEvalHook, EvalHook
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
+
+__all__ = [
+ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+ 'eval_metrics', 'get_classes', 'get_palette'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..532c5fd78946ede66d747ec8e7b72dbb66471aac
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py
@@ -0,0 +1,152 @@
+import annotator.mmpkg.mmcv as mmcv
+
+
+def cityscapes_classes():
+ """Cityscapes class names for external use."""
+ return [
+ 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle'
+ ]
+
+
+def ade_classes():
+ """ADE20K class names for external use."""
+ return [
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag'
+ ]
+
+
+def voc_classes():
+ """Pascal VOC class names for external use."""
+ return [
+ 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+ 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+ 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+ 'tvmonitor'
+ ]
+
+
+def cityscapes_palette():
+ """Cityscapes palette for external use."""
+ return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+ [0, 0, 230], [119, 11, 32]]
+
+
+def ade_palette():
+ """ADE20K palette for external use."""
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+
+def voc_palette():
+ """Pascal VOC palette for external use."""
+ return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+
+dataset_aliases = {
+ 'cityscapes': ['cityscapes'],
+ 'ade': ['ade', 'ade20k'],
+ 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
+}
+
+
+def get_classes(dataset):
+    """Get class names of a dataset.
+
+    Args:
+        dataset (str): A dataset name or alias listed in ``dataset_aliases``.
+
+    Returns:
+        list[str]: Result of the matching ``<name>_classes()`` helper.
+    """
+    if not mmcv.is_str(dataset):
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+    alias2name = {a: n for n, names in dataset_aliases.items() for a in names}
+    if dataset not in alias2name:
+        raise ValueError(f'Unrecognized dataset: {dataset}')
+    return globals()[alias2name[dataset] + '_classes']()  # no eval() needed
+
+
+def get_palette(dataset):
+    """Get class palette (RGB) of a dataset.
+
+    Args:
+        dataset (str): A dataset name or alias listed in ``dataset_aliases``.
+
+    Returns:
+        list[list[int]]: Result of the matching ``<name>_palette()`` helper.
+    """
+    if not mmcv.is_str(dataset):
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+    alias2name = {a: n for n, names in dataset_aliases.items() for a in names}
+    if dataset not in alias2name:
+        raise ValueError(f'Unrecognized dataset: {dataset}')
+    return globals()[alias2name[dataset] + '_palette']()  # no eval() needed
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..408e9670f61d1b118477562b341adc644c52799a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py
@@ -0,0 +1,109 @@
+import os.path as osp
+
+from annotator.mmpkg.mmcv.runner import DistEvalHook as _DistEvalHook
+from annotator.mmpkg.mmcv.runner import EvalHook as _EvalHook
+
+
+class EvalHook(_EvalHook):
+    """Single GPU EvalHook, with efficient test support.
+
+    Args:
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            Default: False.
+        efficient_test (bool): Whether save the results as local numpy files to
+            save CPU memory during evaluation. Default: False.
+    """
+
+    greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+        super().__init__(*args, by_epoch=by_epoch, **kwargs)
+        self.efficient_test = efficient_test
+
+    def after_train_iter(self, runner):
+        """After train iter hook.
+
+        Override default ``single_gpu_test``; skipped when ``by_epoch`` is set.
+        """
+        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+            return
+        from annotator.mmpkg.mmseg.apis import single_gpu_test
+        runner.log_buffer.clear()
+        results = single_gpu_test(
+            runner.model,
+            self.dataloader,
+            show=False,
+            efficient_test=self.efficient_test)
+        self.evaluate(runner, results)
+
+    def after_train_epoch(self, runner):
+        """After train epoch hook.
+
+        Override default ``single_gpu_test``; skipped unless ``by_epoch``.
+        """
+        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+            return
+        from annotator.mmpkg.mmseg.apis import single_gpu_test
+        runner.log_buffer.clear()
+        results = single_gpu_test(
+            runner.model, self.dataloader, show=False,
+            efficient_test=self.efficient_test)
+        self.evaluate(runner, results)
+
+
+class DistEvalHook(_DistEvalHook):
+    """Distributed EvalHook, with efficient test support.
+
+    Args:
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            Default: False.
+        efficient_test (bool): Whether save the results as local numpy files to
+            save CPU memory during evaluation. Forwarded to ``multi_gpu_test``
+            by both the per-iteration and per-epoch hooks. Default: False.
+    """
+
+    greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+        super().__init__(*args, by_epoch=by_epoch, **kwargs)
+        self.efficient_test = efficient_test
+
+    def after_train_iter(self, runner):
+        """After train iter hook.
+
+        Override default ``multi_gpu_test``; skipped when ``by_epoch`` is set.
+        """
+        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+            return
+        from annotator.mmpkg.mmseg.apis import multi_gpu_test
+        runner.log_buffer.clear()
+        results = multi_gpu_test(
+            runner.model,
+            self.dataloader,
+            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            gpu_collect=self.gpu_collect,
+            efficient_test=self.efficient_test)
+        if runner.rank == 0:
+            print('\n')
+            self.evaluate(runner, results)
+
+    def after_train_epoch(self, runner):
+        """After train epoch hook.
+
+        Override default ``multi_gpu_test``; skipped unless ``by_epoch``.
+        """
+        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+            return
+        from annotator.mmpkg.mmseg.apis import multi_gpu_test
+        runner.log_buffer.clear()
+        results = multi_gpu_test(
+            runner.model,
+            self.dataloader,
+            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            gpu_collect=self.gpu_collect,
+            efficient_test=self.efficient_test)
+        if runner.rank == 0:
+            print('\n')
+            self.evaluate(runner, results)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ede737624a0ba6e6365639f7019ac2527052cfd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py
@@ -0,0 +1,326 @@
+from collections import OrderedDict
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+import torch
+
+
+def f_score(precision, recall, beta=1):
+    """Compute the F-beta score from precision and recall.
+
+    Args:
+        precision (float | torch.Tensor): The precision value.
+        recall (float | torch.Tensor): The recall value.
+        beta (int): Weight of recall relative to precision in the combined
+            score. Default: 1.
+
+    Returns:
+        float | torch.Tensor: The f-score value.
+    """
+    b2 = beta**2
+    weighted_sum = (b2 * precision) + recall
+    return (1 + b2) * (precision * recall) / weighted_sum
+
+
+def intersect_and_union(pred_label,
+                        label,
+                        num_classes,
+                        ignore_index,
+                        label_map=dict(),
+                        reduce_zero_label=False):
+    """Calculate intersection and Union.
+
+    Args:
+        pred_label (ndarray | str): Prediction segmentation map
+            or predict result filename.
+        label (ndarray | str): Ground truth segmentation map
+            or label filename.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        label_map (dict): Mapping old labels to new labels; each old id is
+            matched against the original labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label (0 becomes 255,
+            the remaining labels shift down by one). Default: False.
+
+    Returns:
+        torch.Tensor: The intersection of prediction and ground truth
+            histogram on all classes.
+        torch.Tensor: The union of prediction and ground truth histogram on
+            all classes.
+        torch.Tensor: The prediction histogram on all classes.
+        torch.Tensor: The ground truth histogram on all classes.
+    """
+
+    if isinstance(pred_label, str):
+        pred_label = np.load(pred_label)
+    pred_label = torch.from_numpy(pred_label)
+
+    if isinstance(label, str):
+        label = mmcv.imread(label, flag='unchanged', backend='pillow')
+    # Work on a copy: ``from_numpy`` shares memory with the caller's array.
+    raw_label = torch.from_numpy(label)
+    label = raw_label.clone()
+
+    if label_map is not None:
+        for old_id, new_id in label_map.items():
+            # Match against the pristine snapshot so remaps never chain.
+            label[raw_label == old_id] = new_id
+    if reduce_zero_label:
+        label[label == 0] = 255
+        label = label - 1
+        label[label == 254] = 255
+
+    mask = (label != ignore_index)
+    pred_label = pred_label[mask]
+    label = label[mask]
+
+    intersect = pred_label[pred_label == label]
+    area_intersect = torch.histc(
+        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
+    area_pred_label = torch.histc(
+        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+    area_label = torch.histc(
+        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+    area_union = area_pred_label + area_label - area_intersect
+    return area_intersect, area_union, area_pred_label, area_label
+
+
+def total_intersect_and_union(results,
+                              gt_seg_maps,
+                              num_classes,
+                              ignore_index,
+                              label_map=dict(),
+                              reduce_zero_label=False):
+    """Accumulate per-image intersection/union histograms over a dataset.
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+
+    Returns:
+        torch.Tensor: The intersection of prediction and ground truth
+            histogram on all classes, summed over all images.
+        torch.Tensor: The union of prediction and ground truth histogram on
+            all classes, summed over all images.
+        torch.Tensor: The summed prediction histogram on all classes.
+        torch.Tensor: The summed ground truth histogram on all classes.
+    """
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    # Accumulators, in order: intersect, union, prediction, ground truth.
+    totals = tuple(
+        torch.zeros((num_classes, ), dtype=torch.float64)
+        for _ in range(4))
+    for pred, gt_seg_map in zip(results, gt_seg_maps):
+        per_image = intersect_and_union(
+            pred, gt_seg_map, num_classes, ignore_index, label_map,
+            reduce_zero_label)
+        # In-place adds keep the float64 accumulators allocated above.
+        for total, area in zip(totals, per_image):
+            total += area
+    total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label = totals
+    return total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label
+
+
+def mean_iou(results,
+             gt_seg_maps,
+             num_classes,
+             ignore_index,
+             nan_to_num=None,
+             label_map=dict(),
+             reduce_zero_label=False):
+    """Calculate Mean Intersection and Union (mIoU).
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+
+    Returns:
+        dict[str, ndarray]: Output of ``eval_metrics`` for 'mIoU':
+            'aAcc': Overall accuracy on all images.
+            'Acc': Per category accuracy, shape (num_classes, ).
+            'IoU': Per category IoU, shape (num_classes, ).
+    """
+    # Thin wrapper: delegate to ``eval_metrics`` restricted to mIoU.
+    return eval_metrics(
+        results,
+        gt_seg_maps,
+        num_classes,
+        ignore_index,
+        metrics=['mIoU'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label)
+
+
+def mean_dice(results,
+              gt_seg_maps,
+              num_classes,
+              ignore_index,
+              nan_to_num=None,
+              label_map=dict(),
+              reduce_zero_label=False):
+    """Calculate Mean Dice (mDice).
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+
+    Returns:
+        dict[str, ndarray]: Output of ``eval_metrics`` for 'mDice':
+            'aAcc': Overall accuracy on all images.
+            'Acc': Per category accuracy, shape (num_classes, ).
+            'Dice': Per category dice, shape (num_classes, ).
+    """
+
+    # Thin wrapper: delegate to ``eval_metrics`` restricted to mDice.
+    return eval_metrics(
+        results,
+        gt_seg_maps,
+        num_classes,
+        ignore_index,
+        metrics=['mDice'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label)
+
+
+def mean_fscore(results,
+                gt_seg_maps,
+                num_classes,
+                ignore_index,
+                nan_to_num=None,
+                label_map=dict(),
+                reduce_zero_label=False,
+                beta=1):
+    """Calculate Mean F-Score (mFscore).
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        dict[str, ndarray]: Output of ``eval_metrics`` for 'mFscore':
+            'aAcc': Overall accuracy on all images.
+            'Fscore': Per category f-score, shape (num_classes, ).
+            'Precision': Per category precision, shape
+                (num_classes, ).
+            'Recall': Per category recall, shape (num_classes, ).
+    """
+    # Thin wrapper: delegate to ``eval_metrics`` restricted to mFscore.
+    return eval_metrics(
+        results,
+        gt_seg_maps,
+        num_classes,
+        ignore_index,
+        metrics=['mFscore'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label,
+        beta=beta)
+
+
+def eval_metrics(results,
+                 gt_seg_maps,
+                 num_classes,
+                 ignore_index,
+                 metrics=['mIoU'],
+                 nan_to_num=None,
+                 label_map=dict(),
+                 reduce_zero_label=False,
+                 beta=1):
+    """Calculate evaluation metrics.
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        metrics (list[str] | str): Any of 'mIoU', 'mDice' and 'mFscore'.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether ignore zero label. Default: False.
+        beta (int): Weight of recall in the F-score. Default: 1.
+    Returns:
+        OrderedDict[str, ndarray]: 'aAcc' (overall accuracy, scalar) plus one
+            (num_classes, ) array per result of every requested metric.
+    """
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+    if not set(metrics).issubset(set(allowed_metrics)):
+        raise KeyError('metrics {} is not supported'.format(metrics))
+
+    total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label = total_intersect_and_union(
+            results, gt_seg_maps, num_classes, ignore_index, label_map,
+            reduce_zero_label)
+    all_acc = total_area_intersect.sum() / total_area_label.sum()
+    ret_metrics = OrderedDict({'aAcc': all_acc})
+    for metric in metrics:
+        if metric == 'mIoU':
+            iou = total_area_intersect / total_area_union
+            acc = total_area_intersect / total_area_label
+            ret_metrics['IoU'] = iou
+            ret_metrics['Acc'] = acc
+        elif metric == 'mDice':
+            dice = 2 * total_area_intersect / (
+                total_area_pred_label + total_area_label)
+            acc = total_area_intersect / total_area_label
+            ret_metrics['Dice'] = dice
+            ret_metrics['Acc'] = acc
+        elif metric == 'mFscore':
+            precision = total_area_intersect / total_area_pred_label
+            recall = total_area_intersect / total_area_label
+            # ``f_score`` is elementwise, so no per-class Python loop needed.
+            f_value = f_score(precision, recall, beta)
+            ret_metrics['Fscore'] = f_value
+            ret_metrics['Precision'] = precision
+            ret_metrics['Recall'] = recall
+
+    ret_metrics = OrderedDict(
+        (metric, value.numpy())
+        for metric, value in ret_metrics.items()
+    )
+    if nan_to_num is not None:
+        ret_metrics = OrderedDict({
+            metric: np.nan_to_num(metric_value, nan=nan_to_num)
+            for metric, metric_value in ret_metrics.items()
+        })
+    return ret_metrics
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_pixel_sampler
+from .sampler import BasePixelSampler, OHEMPixelSampler
+
+__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8fff6375622282f85b3acf15af1a7d27fb9c426
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py
@@ -0,0 +1,8 @@
+from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
+
+PIXEL_SAMPLERS = Registry('pixel sampler')
+
+
+def build_pixel_sampler(cfg, **default_args):
+    """Build a pixel sampler from ``cfg`` via the PIXEL_SAMPLERS registry."""
+    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base_pixel_sampler import BasePixelSampler
+from .ohem_pixel_sampler import OHEMPixelSampler
+
+__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py
@@ -0,0 +1,12 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BasePixelSampler(metaclass=ABCMeta):
+    """Abstract base class for pixel samplers (``sample`` is abstract)."""
+
+    def __init__(self, **kwargs):
+        pass  # keyword arguments are accepted and ignored
+
+    @abstractmethod
+    def sample(self, seg_logit, seg_label):
+        """Abstract hook: subclasses provide the sampling logic."""
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+
+from ..builder import PIXEL_SAMPLERS
+from .base_pixel_sampler import BasePixelSampler
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+    """Online Hard Example Mining Sampler for segmentation.
+
+    Args:
+        context (nn.Module): The context of sampler, subclass of
+            :obj:`BaseDecodeHead`.
+        thresh (float, optional): The threshold for hard example selection.
+            Below which, are prediction with low confidence. If not
+            specified, the hard examples will be pixels of top ``min_kept``
+            loss. Default: None.
+        min_kept (int, optional): The minimum number of predictions to keep
+            per sample of the batch. Default: 100000.
+    """
+
+    def __init__(self, context, thresh=None, min_kept=100000):
+        super(OHEMPixelSampler, self).__init__()
+        self.context = context
+        assert min_kept > 1
+        self.thresh = thresh
+        self.min_kept = min_kept
+
+    def sample(self, seg_logit, seg_label):
+        """Sample pixels that have high loss or with low prediction confidence.
+
+        Args:
+            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+
+        Returns:
+            torch.Tensor: segmentation weight, shape (N, H, W)
+        """
+        with torch.no_grad():
+            assert seg_logit.shape[2:] == seg_label.shape[2:]
+            assert seg_label.shape[1] == 1
+            seg_label = seg_label.squeeze(1).long()
+            batch_kept = self.min_kept * seg_label.size(0)
+            valid_mask = seg_label != self.context.ignore_index
+            seg_weight = seg_logit.new_zeros(size=seg_label.size())
+            valid_seg_weight = seg_weight[valid_mask]
+            if self.thresh is not None:
+                seg_prob = F.softmax(seg_logit, dim=1)
+                # Ignored pixels get class 0 so ``gather`` sees valid indices.
+                tmp_seg_label = seg_label.clone().unsqueeze(1)
+                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                sort_prob, sort_indices = seg_prob[valid_mask].sort()
+                # Raise the cut-off so roughly ``batch_kept`` pixels survive.
+                if sort_prob.numel() > 0:
+                    min_threshold = sort_prob[min(batch_kept,
+                                                  sort_prob.numel() - 1)]
+                else:
+                    min_threshold = 0.0
+                threshold = max(min_threshold, self.thresh)
+                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+            else:
+                losses = self.context.loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=None,
+                    ignore_index=self.context.ignore_index,
+                    reduction_override='none')
+                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
+                _, sort_indices = losses[valid_mask].sort(descending=True)
+                valid_seg_weight[sort_indices[:batch_kept]] = 1.
+            # Write the weights of the valid pixels back into the full map.
+            seg_weight[valid_mask] = valid_seg_weight
+
+            return seg_weight
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py
@@ -0,0 +1,3 @@
+from .misc import add_prefix
+
+__all__ = ['add_prefix']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py
@@ -0,0 +1,17 @@
+def add_prefix(inputs, prefix):
+    """Return a copy of ``inputs`` whose keys are prefixed with ``prefix``.
+
+    Args:
+        inputs (dict): The input dict with str keys.
+        prefix (str): The prefix to add.
+
+    Returns:
+        dict: A new dict mapping ``'{prefix}.{name}'`` to the original
+            value of every ``name`` in ``inputs``.
+    """
+
+    prefixed = {
+        f'{prefix}.{name}': value
+        for name, value in inputs.items()
+    }
+    return prefixed
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py
@@ -0,0 +1,19 @@
+from .ade import ADE20KDataset
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .chase_db1 import ChaseDB1Dataset
+from .cityscapes import CityscapesDataset
+from .custom import CustomDataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .drive import DRIVEDataset
+from .hrf import HRFDataset
+from .pascal_context import PascalContextDataset, PascalContextDataset59
+from .stare import STAREDataset
+from .voc import PascalVOCDataset
+
+__all__ = [
+ 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
+ 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
+ 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
+ 'STAREDataset'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/ade.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/ade.py
new file mode 100644
index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/ade.py
@@ -0,0 +1,84 @@
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ADE20KDataset(CustomDataset):
+ """ADE20K dataset.
+
+ In segmentation map annotation for ADE20K, 0 stands for background, which
+ is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+ The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+ '.png'.
+ """
+ CLASSES = (
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+ 'clock', 'flag')
+
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+ [102, 255, 0], [92, 0, 255]]
+
+    def __init__(self, **kwargs):
+        super().__init__(
+            reduce_zero_label=True,
+            seg_map_suffix='.png',
+            img_suffix='.jpg',
+            **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/builder.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf8b4d9d32d4464905507cd54a84eb534f38bb6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/builder.py
@@ -0,0 +1,169 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from annotator.mmpkg.mmcv.parallel import collate
+from annotator.mmpkg.mmcv.runner import get_dist_info
+from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
+from torch.utils.data import DistributedSampler
+
+if platform.system() != 'Windows':
+    # https://github.com/pytorch/pytorch/issues/973
+    import resource
+    # Raise the soft limit to 4096 when possible, but never lower it.
+    base_soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
+    soft_limit = min(max(4096, base_soft_limit), hard_limit)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+    """Build a :obj:`ConcatDataset`, one dataset per ``img_dir``/``split``."""
+    from .dataset_wrappers import ConcatDataset
+    img_dir = cfg['img_dir']
+    ann_dir = cfg.get('ann_dir', None)
+    split = cfg.get('split', None)
+    num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
+    if ann_dir is not None:
+        num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
+    else:
+        num_ann_dir = 0
+    if split is not None:
+        num_split = len(split) if isinstance(split, (list, tuple)) else 1
+    else:
+        num_split = 0
+    if num_img_dir > 1:
+        assert num_img_dir == num_ann_dir or num_ann_dir == 0
+        assert num_img_dir == num_split or num_split == 0
+    else:
+        assert num_split == num_ann_dir or num_ann_dir <= 1
+    num_dset = max(num_split, num_img_dir)
+    # One deep-copied sub-config per entry; scalar fields are shared by all.
+    datasets = []
+    for i in range(num_dset):
+        data_cfg = copy.deepcopy(cfg)
+        if isinstance(img_dir, (list, tuple)):
+            data_cfg['img_dir'] = img_dir[i]
+        if isinstance(ann_dir, (list, tuple)):
+            data_cfg['ann_dir'] = ann_dir[i]
+        if isinstance(split, (list, tuple)):
+            data_cfg['split'] = split[i]
+        datasets.append(build_dataset(data_cfg, default_args))
+
+    return ConcatDataset(datasets)
+
+
+def build_dataset(cfg, default_args=None):
+    """Build a dataset (or a wrapper around several datasets) from ``cfg``."""
+    from .dataset_wrappers import ConcatDataset, RepeatDataset
+    seq = (list, tuple)
+    if isinstance(cfg, seq):
+        return ConcatDataset([build_dataset(c, default_args) for c in cfg])
+    if cfg['type'] == 'RepeatDataset':
+        inner = build_dataset(cfg['dataset'], default_args)
+        return RepeatDataset(inner, cfg['times'])
+    # List-valued ``img_dir``/``split`` expand into one dataset per entry.
+    if isinstance(cfg.get('img_dir'), seq) or isinstance(
+            cfg.get('split', None), seq):
+        return _concat_dataset(cfg, default_args)
+
+    return build_from_cfg(cfg, DATASETS, default_args)
+
+
+def build_dataloader(dataset,
+                     samples_per_gpu,
+                     workers_per_gpu,
+                     num_gpus=1,
+                     dist=True,
+                     shuffle=True,
+                     seed=None,
+                     drop_last=False,
+                     pin_memory=True,
+                     dataloader_type='PoolDataLoader',
+                     **kwargs):
+    """Build PyTorch DataLoader.
+
+    With ``dist=True`` every process builds its own loader on top of a
+    ``DistributedSampler``; otherwise one loader serves all ``num_gpus``.
+
+    Args:
+        dataset (Dataset): A PyTorch dataset.
+        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+            batch size of each GPU.
+        workers_per_gpu (int): How many subprocesses to use for data loading
+            for each GPU.
+        num_gpus (int): Number of GPUs. Only used in non-distributed training.
+        dist (bool): Distributed training/test or not. Default: True.
+        shuffle (bool): Whether to shuffle the data at every epoch.
+            Default: True.
+        seed (int | None): Seed to be used. Default: None.
+        drop_last (bool): Whether to drop the last incomplete batch in epoch.
+            Default: False
+        pin_memory (bool): Whether to use pin_memory in DataLoader.
+            Default: True
+        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        kwargs: any keyword argument to be used to initialize DataLoader
+
+    Returns:
+        DataLoader: A PyTorch dataloader.
+    """
+    rank, world_size = get_dist_info()
+    if dist:
+        # Shuffling is handed to the sampler, so the loader must not shuffle.
+        sampler = DistributedSampler(
+            dataset, world_size, rank, shuffle=shuffle)
+        shuffle = False
+        batch_size, num_workers = samples_per_gpu, workers_per_gpu
+    else:
+        # One loader serves all GPUs, so scale batch size and workers.
+        sampler = None
+        batch_size = num_gpus * samples_per_gpu
+        num_workers = num_gpus * workers_per_gpu
+
+    init_fn = None
+    if seed is not None:
+        init_fn = partial(
+            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
+
+    loader_classes = {
+        'DataLoader': DataLoader,
+        'PoolDataLoader': PoolDataLoader,
+    }
+    assert dataloader_type in loader_classes, \
+        f'unsupported dataloader {dataloader_type}'
+
+    data_loader = loader_classes[dataloader_type](
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        num_workers=num_workers,
+        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+        pin_memory=pin_memory,
+        shuffle=shuffle,
+        worker_init_fn=init_fn,
+        drop_last=drop_last,
+        **kwargs)
+
+    return data_loader
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    """Seed NumPy and ``random`` inside a dataloader worker.
+
+    The seed of each worker equals num_workers * rank + worker_id + seed.
+
+    Args:
+        worker_id (int): Worker id.
+        num_workers (int): Number of workers.
+        rank (int): The rank of current process.
+        seed (int): The random seed to use.
+    """
+
+    derived_seed = seed + worker_id + rank * num_workers
+    for seed_fn in (np.random.seed, random.seed):
+        seed_fn(derived_seed)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ChaseDB1Dataset(CustomDataset):
+    """CHASE_DB1 dataset with two classes (background and vessel).
+
+    Label 0 is the background class and is evaluated like any other class,
+    so ``reduce_zero_label`` is fixed to False. ``img_suffix`` is fixed to
+    '.png' and ``seg_map_suffix`` to '_1stHO.png'; ``img_dir`` must exist
+    when the dataset is constructed.
+    """
+
+    CLASSES = ('background', 'vessel')
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    def __init__(self, **kwargs):
+        super().__init__(
+            reduce_zero_label=False,
+            seg_map_suffix='_1stHO.png',
+            img_suffix='.png',
+            **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f80e8043d25178cf5dac18911241c74be4e3ac
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py
@@ -0,0 +1,217 @@
+import os.path as osp
+import tempfile
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+from annotator.mmpkg.mmcv.utils import print_log
+from PIL import Image
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class CityscapesDataset(CustomDataset):
+ """Cityscapes dataset.
+
+ The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
+ fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
+ """
+
+ CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+
+ PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
+ [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+
+ def __init__(self, **kwargs):
+ super(CityscapesDataset, self).__init__(
+ img_suffix='_leftImg8bit.png',
+ seg_map_suffix='_gtFine_labelTrainIds.png',
+ **kwargs)
+
+ @staticmethod
+ def _convert_to_label_id(result):
+ """Convert trainId to id for cityscapes."""
+ if isinstance(result, str):
+ result = np.load(result)
+ import cityscapesscripts.helpers.labels as CSLabels
+ result_copy = result.copy()
+ for trainId, label in CSLabels.trainId2label.items():
+ result_copy[result == trainId] = label.id
+
+ return result_copy
+
+ def results2img(self, results, imgfile_prefix, to_label_id):
+     """Write the segmentation results to paletted png images.
+
+     Args:
+         results (list[list | tuple | ndarray]): Testing results of the
+             dataset.
+         imgfile_prefix (str): Directory the png files are written to.
+             Given "somepath/xxx", the image whose base name is "yyy" is
+             saved as "somepath/xxx/yyy.png".
+         to_label_id (bool): whether convert output to label_id for
+             submission
+
+     Returns:
+         list[str]: paths of the png files holding the semantic
+             segmentation results.
+     """
+     # The palette depends only on the Cityscapes label table, so build it
+     # once instead of once per image.
+     import cityscapesscripts.helpers.labels as CSLabels
+     palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
+     for label_id, label in CSLabels.id2label.items():
+         palette[label_id] = label.color
+
+     mmcv.mkdir_or_exist(imgfile_prefix)
+     result_files = []
+     prog_bar = mmcv.ProgressBar(len(self))
+     for idx in range(len(self)):
+         result = results[idx]
+         if to_label_id:
+             result = self._convert_to_label_id(result)
+         filename = self.img_infos[idx]['filename']
+         basename = osp.splitext(osp.basename(filename))[0]
+         png_filename = osp.join(imgfile_prefix, f'{basename}.png')
+
+         output = Image.fromarray(result.astype(np.uint8)).convert('P')
+         output.putpalette(palette)
+         output.save(png_filename)
+         result_files.append(png_filename)
+         prog_bar.update()
+     return result_files
+
+ def format_results(self, results, imgfile_prefix=None, to_label_id=True):
+ """Format the results into dir (standard format for Cityscapes
+ evaluation).
+
+ Args:
+ results (list): Testing results of the dataset.
+ imgfile_prefix (str | None): The prefix of images files. It
+ includes the file path and the prefix of filename, e.g.,
+ "a/b/prefix". If not specified, a temp directory will be created.
+ Default: None.
+ to_label_id (bool): whether convert output to label_id for
+ submission. Default: True
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a list containing
+ the image paths, tmp_dir is the temporary directory created
+ for saving png files when imgfile_prefix is not specified.
+ """
+
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: '
+ f'{len(results)} != {len(self)}')
+
+ if imgfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ imgfile_prefix = tmp_dir.name
+ else:
+ tmp_dir = None
+ result_files = self.results2img(results, imgfile_prefix, to_label_id)
+
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ imgfile_prefix=None,
+ efficient_test=False):
+ """Evaluation in Cityscapes/default protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. 'cityscapes' runs
+ the Cityscapes protocol; other metrics go to the parent class.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file,
+ for cityscapes evaluation only. It includes the file path and
+ the prefix of filename, e.g., "a/b/prefix". The output files
+ would be png images under folder "a/b/prefix/xxx.png", where
+ "xxx" is the image name of cityscapes. If not specified, a temp
+ directory will be created for evaluation. Default: None.
+ efficient_test (bool): Passed through to the parent class
+ ``evaluate`` for the non-cityscapes metrics. Default: False.
+
+ Returns:
+ dict[str, float]: Cityscapes/default metrics.
+ """
+
+ eval_results = dict()
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
+ if 'cityscapes' in metrics:
+ eval_results.update(
+ self._evaluate_cityscapes(results, logger, imgfile_prefix))
+ metrics.remove('cityscapes')
+ if len(metrics) > 0:
+ eval_results.update(
+ super(CityscapesDataset,
+ self).evaluate(results, metrics, logger, efficient_test))
+
+ return eval_results
+
+ def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
+ """Evaluation in Cityscapes protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ imgfile_prefix (str | None): The prefix of output image file
+
+ Returns:
+ dict[str: float]: Cityscapes evaluation results.
+ """
+ try:
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
+ except ImportError:
+ raise ImportError('Please run "pip install cityscapesscripts" to '
+ 'install cityscapesscripts first.')
+ msg = 'Evaluating in Cityscapes style'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ result_files, tmp_dir = self.format_results(results, imgfile_prefix)
+
+ if tmp_dir is None:
+ result_dir = imgfile_prefix
+ else:
+ result_dir = tmp_dir.name
+
+ eval_results = dict()
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+
+ CSEval.args.evalInstLevelScore = True
+ CSEval.args.predictionPath = osp.abspath(result_dir)
+ CSEval.args.evalPixelAccuracy = True
+ CSEval.args.JSONOutput = False
+
+ seg_map_list = []
+ pred_list = []
+
+ # when evaluating with official cityscapesscripts,
+ # **_gtFine_labelIds.png is used
+ for seg_map in mmcv.scandir(
+ self.ann_dir, 'gtFine_labelIds.png', recursive=True):
+ seg_map_list.append(osp.join(self.ann_dir, seg_map))
+ pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
+
+ eval_results.update(
+ CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
+
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+
+ return eval_results
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/custom.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a626976c7fa88c3d1c1e871ef621422acc1be83
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/custom.py
@@ -0,0 +1,403 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+from functools import reduce
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+from annotator.mmpkg.mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from annotator.mmpkg.mmseg.core import eval_metrics
+from annotator.mmpkg.mmseg.utils import get_root_logger
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDataset(Dataset):
+ """Custom dataset for semantic segmentation. An example of file structure
+ is as followed.
+
+ .. code-block:: none
+
+ ├── data
+ │ ├── my_dataset
+ │ │ ├── img_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{img_suffix}
+ │ │ │ │ ├── yyy{img_suffix}
+ │ │ │ │ ├── zzz{img_suffix}
+ │ │ │ ├── val
+ │ │ ├── ann_dir
+ │ │ │ ├── train
+ │ │ │ │ ├── xxx{seg_map_suffix}
+ │ │ │ │ ├── yyy{seg_map_suffix}
+ │ │ │ │ ├── zzz{seg_map_suffix}
+ │ │ │ ├── val
+
+ The img/gt_semantic_seg pair of CustomDataset should be of the same
+ except suffix. A valid img/gt_semantic_seg filename pair should be like
+ ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
+ in the suffix). If split is given, then ``xxx`` is specified in txt file.
+ Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
+ Please refer to ``docs/tutorials/new_dataset.md`` for more details.
+
+
+ Args:
+ pipeline (list[dict]): Processing pipeline
+ img_dir (str): Path to image directory
+ img_suffix (str): Suffix of images. Default: '.jpg'
+ ann_dir (str, optional): Path to annotation directory. Default: None
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+ split (str, optional): Split txt file. If split is specified, only
+ file with suffix in the splits will be loaded. Otherwise, all
+ images in img_dir/ann_dir will be loaded. Default: None
+ data_root (str, optional): Data root for img_dir/ann_dir. Default:
+ None.
+ test_mode (bool): If test_mode=True, gt wouldn't be loaded.
+ ignore_index (int): The label index to be ignored. Default: 255
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
+ Default: False
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
+ The palette of segmentation map. If None is given, and
+ self.PALETTE is None, random palette will be generated.
+ Default: None
+ """
+
+ CLASSES = None
+
+ PALETTE = None
+
+ def __init__(self,
+ pipeline,
+ img_dir,
+ img_suffix='.jpg',
+ ann_dir=None,
+ seg_map_suffix='.png',
+ split=None,
+ data_root=None,
+ test_mode=False,
+ ignore_index=255,
+ reduce_zero_label=False,
+ classes=None,
+ palette=None):
+ self.pipeline = Compose(pipeline)
+ self.img_dir = img_dir
+ self.img_suffix = img_suffix
+ self.ann_dir = ann_dir
+ self.seg_map_suffix = seg_map_suffix
+ self.split = split
+ self.data_root = data_root
+ self.test_mode = test_mode
+ self.ignore_index = ignore_index
+ self.reduce_zero_label = reduce_zero_label
+ self.label_map = None
+ self.CLASSES, self.PALETTE = self.get_classes_and_palette(
+ classes, palette)
+
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.img_dir):
+ self.img_dir = osp.join(self.data_root, self.img_dir)
+ if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
+ self.ann_dir = osp.join(self.data_root, self.ann_dir)
+ if not (self.split is None or osp.isabs(self.split)):
+ self.split = osp.join(self.data_root, self.split)
+
+ # load annotations
+ self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
+ self.ann_dir,
+ self.seg_map_suffix, self.split)
+
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.img_infos)
+
+ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
+                      split):
+     """Load annotation from directory.
+
+     Args:
+         img_dir (str): Path to image directory
+         img_suffix (str): Suffix of images.
+         ann_dir (str|None): Path to annotation directory.
+         seg_map_suffix (str|None): Suffix of segmentation maps.
+         split (str|None): Split txt file. If split is specified, only file
+             with suffix in the splits will be loaded. Otherwise, all images
+             in img_dir/ann_dir will be loaded. Default: None
+
+     Returns:
+         list[dict]: All image info of dataset.
+     """
+     img_infos = []
+     if split is not None:
+         with open(split) as f:
+             for line in f:
+                 img_name = line.strip()
+                 img_info = dict(filename=img_name + img_suffix)
+                 if ann_dir is not None:
+                     seg_map = img_name + seg_map_suffix
+                     img_info['ann'] = dict(seg_map=seg_map)
+                 img_infos.append(img_info)
+     else:
+         for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
+             img_info = dict(filename=img)
+             if ann_dir is not None:
+                 # Swap only the trailing suffix; replace() hits every match.
+                 seg_map = img[:len(img) - len(img_suffix)] + seg_map_suffix
+                 img_info['ann'] = dict(seg_map=seg_map)
+             img_infos.append(img_info)
+
+     print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
+     return img_infos
+
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.img_infos[idx]['ann']
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['seg_fields'] = []
+ results['img_prefix'] = self.img_dir
+ results['seg_prefix'] = self.ann_dir
+ if self.custom_classes:
+ results['label_map'] = self.label_map
+
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set
+ False).
+ """
+
+ if self.test_mode:
+ return self.prepare_test_img(idx)
+ else:
+ return self.prepare_train_img(idx)
+
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys
+ introduced by pipeline.
+ """
+
+ img_info = self.img_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by
+ pipeline.
+ """
+
+ img_info = self.img_infos[idx]
+ results = dict(img_info=img_info)
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+
+ def get_gt_seg_maps(self, efficient_test=False):
+ """Get ground truth segmentation maps for evaluation."""
+ gt_seg_maps = []
+ for img_info in self.img_infos:
+ seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
+ if efficient_test:
+ gt_seg_map = seg_map
+ else:
+ gt_seg_map = mmcv.imread(
+ seg_map, flag='unchanged', backend='pillow')
+ gt_seg_maps.append(gt_seg_map)
+ return gt_seg_maps
+
+ def get_classes_and_palette(self, classes=None, palette=None):
+     """Get class names and palette of current dataset.
+
+     Args:
+         classes (Sequence[str] | str | None): If classes is None, use
+             default CLASSES defined by builtin dataset. If classes is a
+             string, take it as a file name. The file contains the name of
+             classes where each line contains one class name. If classes is
+             a tuple or list, override the CLASSES defined by the dataset.
+         palette (Sequence[Sequence[int]]] | np.ndarray | None):
+             The palette of segmentation map. If None is given, random
+             palette will be generated. Default: None
+
+     Returns:
+         tuple: (class names, palette) to be used by the dataset.
+     """
+     if classes is None:
+         self.custom_classes = False
+         return self.CLASSES, self.PALETTE
+
+     self.custom_classes = True
+     if isinstance(classes, str):
+         class_names = mmcv.list_from_file(classes)  # classes is a file path
+     elif isinstance(classes, (tuple, list)):
+         class_names = classes
+     else:
+         raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+     if self.CLASSES:
+         # Compare the loaded names, not ``classes``, which may be a path.
+         if not set(class_names).issubset(self.CLASSES):
+             raise ValueError('classes is not a subset of CLASSES.')
+         # label_map maps old label ids to new label ids (-1 if dropped);
+         # used for changing pixel labels in load_annotations.
+         self.label_map = {}
+         for i, c in enumerate(self.CLASSES):
+             if c not in class_names:
+                 self.label_map[i] = -1
+             else:
+                 self.label_map[i] = class_names.index(c)
+
+     palette = self.get_palette_for_custom_classes(class_names, palette)
+     return class_names, palette
+
+ def get_palette_for_custom_classes(self, class_names, palette=None):
+     """Return the palette to use for a custom ``class_names`` selection."""
+     if self.label_map is not None and self.PALETTE is not None:
+         # return subset of the built-in palette (PALETTE must exist for it)
+         palette = []
+         for old_id, new_id in sorted(
+                 self.label_map.items(), key=lambda x: x[1]):
+             if new_id != -1:
+                 palette.append(self.PALETTE[old_id])
+         palette = type(self.PALETTE)(palette)
+
+     elif palette is None:
+         if self.PALETTE is None:
+             palette = np.random.randint(0, 255, size=(len(class_names), 3))
+         else:
+             palette = self.PALETTE
+
+     return palette
+
+ def evaluate(self,
+ results,
+ metric='mIoU',
+ logger=None,
+ efficient_test=False,
+ **kwargs):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+ 'mDice' and 'mFscore' are supported.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+ dict[str, float]: Default metrics.
+ """
+
+ if isinstance(metric, str):
+ metric = [metric]
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+ if not set(metric).issubset(set(allowed_metrics)):
+ raise KeyError('metric {} is not supported'.format(metric))
+ eval_results = {}
+ gt_seg_maps = self.get_gt_seg_maps(efficient_test)
+ if self.CLASSES is None:
+ num_classes = len(
+ reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
+ else:
+ num_classes = len(self.CLASSES)
+ ret_metrics = eval_metrics(
+ results,
+ gt_seg_maps,
+ num_classes,
+ self.ignore_index,
+ metric,
+ label_map=self.label_map,
+ reduce_zero_label=self.reduce_zero_label)
+
+ if self.CLASSES is None:
+ class_names = tuple(range(num_classes))
+ else:
+ class_names = self.CLASSES
+
+ # summary table
+ ret_metrics_summary = OrderedDict({
+ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+
+ # each class table
+ ret_metrics.pop('aAcc', None)
+ ret_metrics_class = OrderedDict({
+ ret_metric: np.round(ret_metric_value * 100, 2)
+ for ret_metric, ret_metric_value in ret_metrics.items()
+ })
+ ret_metrics_class.update({'Class': class_names})
+ ret_metrics_class.move_to_end('Class', last=False)
+
+ try:
+ from prettytable import PrettyTable
+ # for logger
+ class_table_data = PrettyTable()
+ for key, val in ret_metrics_class.items():
+ class_table_data.add_column(key, val)
+
+ summary_table_data = PrettyTable()
+ for key, val in ret_metrics_summary.items():
+ if key == 'aAcc':
+ summary_table_data.add_column(key, [val])
+ else:
+ summary_table_data.add_column('m' + key, [val])
+
+ print_log('per class results:', logger)
+ print_log('\n' + class_table_data.get_string(), logger=logger)
+ print_log('Summary:', logger)
+ print_log('\n' + summary_table_data.get_string(), logger=logger)
+ except ImportError: # prettytable is not installed
+ pass
+
+ # each metric dict
+ for key, value in ret_metrics_summary.items():
+ if key == 'aAcc':
+ eval_results[key] = value / 100.0
+ else:
+ eval_results['m' + key] = value / 100.0
+
+ ret_metrics_class.pop('Class', None)
+ for key, value in ret_metrics_class.items():
+ eval_results.update({
+ key + '.' + str(name): value[idx] / 100.0
+ for idx, name in enumerate(class_names)
+ })
+
+ if mmcv.is_list_of(results, str):
+ for file_name in results:
+ os.remove(file_name)
+ return eval_results
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py
@@ -0,0 +1,50 @@
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but also
+ exposes the ``CLASSES`` and ``PALETTE`` of the first dataset.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ """
+
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ self.PALETTE = datasets[0].PALETTE
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+
+ The length of repeated dataset will be `times` larger than the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ self.PALETTE = dataset.PALETTE
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx):
+ """Get item from original dataset."""
+ return self.dataset[idx % self._ori_len]
+
+ def __len__(self):
+ """The length is multiplied by ``times``"""
+ return self.times * self._ori_len
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/drive.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/drive.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class DRIVEDataset(CustomDataset):
+    """DRIVE dataset.
+
+    Label 0 in the DRIVE annotation maps is the background class and is
+    counted among the 2 categories, so ``reduce_zero_label`` is pinned to
+    False. ``img_suffix`` is pinned to '.png' and ``seg_map_suffix`` to
+    '_manual1.png'.
+    """
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    CLASSES = ('background', 'vessel')
+
+    def __init__(self, **kwargs):
+        fixed = dict(
+            img_suffix='.png',
+            seg_map_suffix='_manual1.png',
+            reduce_zero_label=False)
+        super().__init__(**fixed, **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class HRFDataset(CustomDataset):
+    """HRF dataset.
+
+    Label 0 in the HRF annotation maps is the background class and is
+    counted among the 2 categories, so ``reduce_zero_label`` is pinned to
+    False. ``img_suffix`` is pinned to '.png' and ``seg_map_suffix`` to
+    '.png'.
+    """
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    CLASSES = ('background', 'vessel')
+
+    def __init__(self, **kwargs):
+        fixed = dict(
+            img_suffix='.png',
+            seg_map_suffix='.png',
+            reduce_zero_label=False)
+        super().__init__(**fixed, **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py
@@ -0,0 +1,103 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalContextDataset(CustomDataset):
+ """PascalContext dataset.
+
+ In segmentation map annotation for PascalContext, 0 stands for background,
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
+ 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
+ 'window', 'wood')
+
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=False,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
+
+
+@DATASETS.register_module()
+class PascalContextDataset59(CustomDataset):
+ """PascalContext dataset without the background class.
+
+ ``CLASSES`` lists the 59 PascalContext categories other than 'background'.
+ ``reduce_zero_label`` is fixed to True, so label 0 is marked as ignored.
+ The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+ fixed to '.png'.
+
+ Args:
+ split (str): Split txt file for PascalContext.
+ """
+
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
+ 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
+
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+ def __init__(self, split, **kwargs):
+ super(PascalContextDataset59, self).__init__(
+ img_suffix='.jpg',
+ seg_map_suffix='.png',
+ split=split,
+ reduce_zero_label=True,
+ **kwargs)
+ assert osp.exists(self.img_dir) and self.split is not None
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py
@@ -0,0 +1,16 @@
+from .compose import Compose
+from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
+ Transpose, to_tensor)
+from .loading import LoadAnnotations, LoadImageFromFile
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
+ PhotoMetricDistortion, RandomCrop, RandomFlip,
+ RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
+
+__all__ = [
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+ 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
+ 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+ 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
+ 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..1683e533237ce6420e4a53e477513853d6b33b3e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py
@@ -0,0 +1,51 @@
+import collections
+
+from annotator.mmpkg.mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+ """Compose multiple transforms sequentially.
+
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+
+ Args:
+ data (dict): A result dict contains the data to transform.
+
+ Returns:
+ dict: Transformed data.
+ """
+
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += f' {t}'
+ format_string += '\n)'
+ return format_string
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..82e2e08ff819506bb7a7693be189017d473e677f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py
@@ -0,0 +1,288 @@
+from collections.abc import Sequence
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+import torch
+from annotator.mmpkg.mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+    """Turn a supported Python/NumPy object into a :obj:`torch.Tensor`.
+
+    Handles :class:`torch.Tensor` (returned as is), :class:`numpy.ndarray`,
+    non-string :class:`Sequence`, :class:`int` and :class:`float`.
+
+    Args:
+        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+            be converted.
+    """
+
+    if isinstance(data, torch.Tensor):
+        return data
+    if isinstance(data, np.ndarray):
+        return torch.from_numpy(data)
+    if isinstance(data, Sequence) and not mmcv.is_str(data):
+        return torch.tensor(data)
+    if isinstance(data, int):
+        return torch.LongTensor([data])
+    if isinstance(data, float):
+        return torch.FloatTensor([data])
+    # Anything else has no defined conversion.
+    raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+    """Replace selected entries of the result dict with :obj:`torch.Tensor`.
+
+    Args:
+        keys (Sequence[str]): Keys that need to be converted to Tensor.
+    """
+
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, results):
+        """Apply :func:`to_tensor` to ``results[key]`` for each key.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The same ``results`` object, updated in place, with the
+                selected values converted to :obj:`torch.Tensor`.
+        """
+
+        for name in self.keys:
+            results[name] = to_tensor(results[name])
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+    """Convert image to :obj:`torch.Tensor` by given keys.
+
+    Input images are laid out as (H, W, C) and come out as (C, H, W).
+    A 2-dimensional (H, W) input is expanded first, giving (1, H, W).
+
+    Args:
+        keys (Sequence[str]): Key of images to be converted to Tensor.
+    """
+
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, results):
+        """Move the channel axis to the front of each selected image and
+        convert it to :obj:`torch.Tensor`.
+
+        Args:
+            results (dict): Result dict contains the image data to convert.
+
+        Returns:
+            dict: The same ``results`` object, with each selected image
+                stored as a :obj:`torch.Tensor` in (C, H, W) order.
+        """
+
+        for name in self.keys:
+            image = results[name]
+            if len(image.shape) < 3:
+                # (H, W) -> (H, W, 1) so the 3-axis transpose is valid.
+                image = np.expand_dims(image, -1)
+            results[name] = to_tensor(image.transpose(2, 0, 1))
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+    """Reorder the axes of selected results.
+
+    Args:
+        keys (Sequence[str]): Keys of results to be transposed.
+        order (Sequence[int]): Order of transpose.
+    """
+
+    def __init__(self, keys, order):
+        self.keys = keys
+        self.order = order
+
+    def __call__(self, results):
+        """Call ``.transpose(self.order)`` on every selected entry and store
+        the outcome back under the same key.
+
+        Args:
+            results (dict): Result dict contains the data to transpose.
+
+        Returns:
+            dict: The same ``results`` object, with the selected entries
+                transposed according to ``order``.
+        """
+
+        for name in self.keys:
+            results[name] = results[name].transpose(self.order)
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        return f'{cls_name}(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer(object):
+    """Wrap selected results in :obj:`mmcv.DataContainer`.
+
+    Args:
+        fields (Sequence[dict]): Each field is a dict like
+            ``dict(key='xxx', **kwargs)``. ``results[key]`` is wrapped in a
+            :obj:`mmcv.DataContainer` constructed with ``**kwargs``.
+            Default: ``(dict(key='img', stack=True),
+            dict(key='gt_semantic_seg'))``.
+    """
+
+    def __init__(self,
+                 fields=(dict(key='img', stack=True),
+                         dict(key='gt_semantic_seg'))):
+        self.fields = fields
+
+    def __call__(self, results):
+        """Replace each configured entry of ``results`` with a
+        :obj:`mmcv.DataContainer` holding the previous value.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The same ``results`` object, with the configured entries
+                wrapped in :obj:`mmcv.DataContainer`.
+        """
+
+        for field in self.fields:
+            options = field.copy()  # keep the configured dict intact
+            name = options.pop('key')
+            results[name] = DC(results[name], **options)
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle(object):
+    """Default formatting bundle.
+
+    Formats the common fields "img" and "gt_semantic_seg" in one step.
+    Fields absent from the result dict are left out. Steps per field:
+
+    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
+                       (3)to DataContainer (stack=True)
+    """
+
+    def __call__(self, results):
+        """Format the "img" and "gt_semantic_seg" entries when present.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The same ``results`` object, with the handled fields
+                replaced by stacked :obj:`DataContainer` tensors.
+        """
+
+        if 'img' in results:
+            image = results['img']
+            if len(image.shape) < 3:
+                image = np.expand_dims(image, -1)
+            # (H, W, C) -> (C, H, W), made contiguous before conversion.
+            chw = np.ascontiguousarray(image.transpose(2, 0, 1))
+            results['img'] = DC(to_tensor(chw), stack=True)
+        if 'gt_semantic_seg' in results:
+            # Add a leading axis and cast the labels to int64 (long).
+            seg = results['gt_semantic_seg'][None, ...].astype(np.int64)
+            results['gt_semantic_seg'] = DC(
+                to_tensor(seg), stack=True)
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__
+
+
+@PIPELINES.register_module()
+class Collect(object):
+    """Collect data from the loader relevant to the specific task.
+
+    This is usually the last stage of the data loader pipeline. Typically keys
+    is set to some subset of "img", "gt_semantic_seg".
+
+    The "img_metas" item is always populated. Which entries it holds is set
+    by "meta_keys". By default this includes:
+
+        - "img_shape": shape of the image input to the network as a tuple
+            (h, w, c).  Note that images may be zero padded on the bottom/right
+            if the batch tensor is larger than this shape.
+
+        - "scale_factor": a float indicating the preprocessing scale
+
+        - "flip": a boolean indicating if image flip transform was used
+
+        - "filename": path to the image file
+
+        - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+        - "pad_shape": image shape after padding
+
+        - "img_norm_cfg": a dict of normalization information:
+            - mean - per channel mean subtraction
+            - std - per channel std divisor
+            - to_rgb - bool indicating if bgr was converted to rgb
+
+    Args:
+        keys (Sequence[str]): Keys of results to be collected in ``data``.
+        meta_keys (Sequence[str], optional): Keys of results copied into the
+            dict that is wrapped as ``data['img_metas']``.
+            Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+            'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+            'img_norm_cfg')``
+    """
+
+    def __init__(self,
+                 keys,
+                 meta_keys=('filename', 'ori_filename', 'ori_shape',
+                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
+                            'flip_direction', 'img_norm_cfg')):
+        self.keys = keys
+        self.meta_keys = meta_keys
+
+    def __call__(self, results):
+        """Build a new dict from the selected keys. The ``meta_keys`` entries
+        are bundled into one cpu-only :obj:`mmcv.DataContainer`.
+
+        Args:
+            results (dict): Result dict contains the data to collect.
+
+        Returns:
+            dict: The result dict contains the following keys
+                - keys in ``self.keys``
+                - ``img_metas``
+        """
+
+        # Gather the requested meta entries; a missing key raises KeyError.
+        meta = {name: results[name] for name in self.meta_keys}
+        # 'img_metas' goes in first so it precedes the collected keys.
+        data = {'img_metas': DC(meta, cpu_only=True)}
+        # The selected payload entries are passed through by reference.
+        for name in self.keys:
+            data[name] = results[name]
+        return data
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        return f'{cls_name}(keys={self.keys}, meta_keys={self.meta_keys})'
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ad8c2cb67cb1d2b593217fb1fb2e0ca5834c24f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py
@@ -0,0 +1,153 @@
+import os.path as osp
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+    """Load an image from file.
+
+    Requires "img_info" (a dict that must contain the key "filename");
+    "img_prefix" is optional. Added or updated keys are "filename",
+    "ori_filename", "img", "img_shape", "ori_shape" and "pad_shape" (both same
+    as `img_shape`), "scale_factor" (1.0) and "img_norm_cfg" (means=0, stds=1).
+
+    Args:
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is an uint8 array.
+            Defaults to False.
+        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+            Defaults to 'color'.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+            'cv2'
+    """
+
+    def __init__(self,
+                 to_float32=False,
+                 color_type='color',
+                 file_client_args=dict(backend='disk'),
+                 imdecode_backend='cv2'):
+        self.to_float32 = to_float32
+        self.color_type = color_type
+        self.imdecode_backend = imdecode_backend
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+
+    def __call__(self, results):
+        """Read and decode the image, then record it and its meta entries.
+
+        Args:
+            results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+
+        if self.file_client is None:
+            # The client is created on first use, not in __init__.
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+
+        rel_name = results['img_info']['filename']
+        prefix = results.get('img_prefix')
+        filename = rel_name if prefix is None else osp.join(prefix, rel_name)
+        raw = self.file_client.get(filename)
+        img = mmcv.imfrombytes(
+            raw, flag=self.color_type, backend=self.imdecode_backend)
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        shape = img.shape
+        results['filename'] = filename
+        results['ori_filename'] = rel_name
+        results['img'] = img
+        results['img_shape'] = shape
+        results['ori_shape'] = shape
+        # Initial values for the remaining default meta_keys.
+        results['pad_shape'] = shape
+        results['scale_factor'] = 1.0
+        channels = 1 if len(shape) < 3 else shape[2]
+        results['img_norm_cfg'] = dict(
+            mean=np.zeros(channels, dtype=np.float32),
+            std=np.ones(channels, dtype=np.float32),
+            to_rgb=False)
+        return results
+
+    def __repr__(self):
+        # The fields are joined by bare commas, without spaces.
+        return (f'{self.__class__.__name__}'
+                f'(to_float32={self.to_float32},'
+                f"color_type='{self.color_type}',"
+                f"imdecode_backend='{self.imdecode_backend}')")
+
+
+@PIPELINES.register_module()
+class LoadAnnotations(object):
+    """Load annotations for semantic segmentation.
+
+    Args:
+        reduce_zero_label (bool): Whether reduce all label value by 1.
+            Usually used for datasets where 0 is background label.
+            Default: False.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+            'pillow'
+    """
+
+    def __init__(self,
+                 reduce_zero_label=False,
+                 file_client_args=dict(backend='disk'),
+                 imdecode_backend='pillow'):
+        self.reduce_zero_label = reduce_zero_label
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+        self.imdecode_backend = imdecode_backend
+
+    def __call__(self, results):
+        """Load the segmentation map and append it to ``seg_fields``.
+
+        Args:
+            results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+        Returns:
+            dict: The dict contains loaded semantic segmentation annotations.
+        """
+
+        if self.file_client is None:
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+
+        if results.get('seg_prefix', None) is not None:
+            filename = osp.join(results['seg_prefix'],
+                                results['ann_info']['seg_map'])
+        else:
+            filename = results['ann_info']['seg_map']
+        img_bytes = self.file_client.get(filename)
+        gt_semantic_seg = mmcv.imfrombytes(
+            img_bytes, flag='unchanged',
+            backend=self.imdecode_backend).squeeze().astype(np.uint8)
+        if results.get('label_map', None) is not None:
+            # Compare against a snapshot so chained ids (1->2, 2->3) don't cascade.
+            original_seg = gt_semantic_seg.copy()
+            for old_id, new_id in results['label_map'].items():
+                gt_semantic_seg[original_seg == old_id] = new_id
+        if self.reduce_zero_label:
+            # Labels 0 and 255 both end as 255; any other label k becomes k - 1.
+            gt_semantic_seg[gt_semantic_seg == 0] = 255
+            gt_semantic_seg = gt_semantic_seg - 1
+            gt_semantic_seg[gt_semantic_seg == 254] = 255
+        results['gt_semantic_seg'] = gt_semantic_seg
+        results['seg_fields'].append('gt_semantic_seg')
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
+        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+        return repr_str
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb781d928ed71aceb1abcaef44d3889c00d2261e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,133 @@
+import warnings
+
+import annotator.mmpkg.mmcv as mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug(object):
+    """Test-time augmentation with multiple scales and flipping.
+
+    An example configuration is as follows:
+
+    .. code-block::
+
+        img_scale=(2048, 1024),
+        img_ratios=[0.5, 1.0],
+        flip=True,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ]
+
+    After MultiScaleFlipAug with the above configuration, the results are
+    wrapped into lists of the same length as follows:
+
+    .. code-block::
+
+        dict(
+            img=[...],
+            img_shape=[...],
+            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
+            flip=[False, True, False, True]
+            ...
+        )
+
+    Args:
+        transforms (list[dict]): Transforms to apply in each augmentation.
+        img_scale (None | tuple | list[tuple]): Images scales for resizing.
+        img_ratios (float | list[float]): Image ratios for resizing
+        flip (bool): Whether apply flip augmentation. Default: False.
+        flip_direction (str | list[str]): Flip augmentation directions,
+            options are "horizontal" and "vertical". If flip_direction is list,
+            multiple flip augmentations will be applied.
+            It has no effect when flip == False. Default: "horizontal".
+    """
+
+    def __init__(self,
+                 transforms,
+                 img_scale,
+                 img_ratios=None,
+                 flip=False,
+                 flip_direction='horizontal'):
+        self.transforms = Compose(transforms)
+        if img_ratios is not None:
+            if not isinstance(img_ratios, list):
+                img_ratios = [img_ratios]
+            assert mmcv.is_list_of(img_ratios, float)
+        if img_scale is None:
+            # mode 1: given img_scale=None and a range of image ratio
+            self.img_scale = None
+            assert mmcv.is_list_of(img_ratios, float)
+        elif isinstance(img_scale, tuple) and mmcv.is_list_of(
+                img_ratios, float):
+            assert len(img_scale) == 2
+            # mode 2: given a scale and a range of image ratio
+            self.img_scale = [(int(img_scale[0] * r), int(img_scale[1] * r))
+                              for r in img_ratios]
+        else:
+            # mode 3: given multiple scales
+            self.img_scale = img_scale if isinstance(img_scale,
+                                                     list) else [img_scale]
+        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
+        self.flip = flip
+        self.img_ratios = img_ratios
+        self.flip_direction = flip_direction if isinstance(
+            flip_direction, list) else [flip_direction]
+        assert mmcv.is_list_of(self.flip_direction, str)
+        if not self.flip and self.flip_direction != ['horizontal']:
+            warnings.warn(
+                'flip_direction has no effect when flip is set to False')
+        if self.flip and not any(
+                isinstance(t, dict) and t.get('type') == 'RandomFlip'
+                for t in transforms):  # callables carry no 'type' entry
+            warnings.warn(
+                'flip has no effect when RandomFlip is not in transforms')
+
+    def __call__(self, results):
+        """Call function to apply test time augment transforms on results.
+
+        Args:
+            results (dict): Result dict contains the data to transform.
+
+        Returns:
+           dict[str: list]: The augmented data, where each value is wrapped
+               into a list.
+        """
+
+        aug_data = []
+        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
+            h, w = results['img'].shape[:2]
+            img_scale = [(int(w * ratio), int(h * ratio))
+                         for ratio in self.img_ratios]
+        else:
+            img_scale = self.img_scale
+        flip_aug = [False, True] if self.flip else [False]
+        for scale in img_scale:
+            for flip in flip_aug:
+                for direction in self.flip_direction:
+                    _results = results.copy()
+                    _results['scale'] = scale
+                    _results['flip'] = flip
+                    _results['flip_direction'] = direction
+                    data = self.transforms(_results)
+                    aug_data.append(data)
+        # list of dict to dict of list
+        aug_data_dict = {key: [] for key in aug_data[0]}
+        for data in aug_data:
+            for key, val in data.items():
+                aug_data_dict[key].append(val)
+        return aug_data_dict
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(transforms={self.transforms}, '
+        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+        repr_str += f'flip_direction={self.flip_direction})'
+        return repr_str
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..842763db97685dd9280424204d62ee65993fdd5a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py
@@ -0,0 +1,889 @@
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+from annotator.mmpkg.mmcv.utils import deprecated_api_warning, is_tuple_of
+from numpy import random
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Resize(object):
+    """Resize images & seg.
+
+    This transform resizes the input image to some scale. If the input dict
+    contains the key "scale", then the scale in the input dict is used,
+    otherwise the specified scale in the init method is used.
+
+    ``img_scale`` can be None, a tuple (single-scale) or a list of tuple
+    (multi-scale). There are 4 multiscale modes:
+
+    - ``ratio_range is not None``:
+        1. When img_scale is None, the (width, height) of the image in
+           results is the base scale and a sampled ratio is applied to
+           it. (mode 1)
+        2. When img_scale is a tuple (single-scale), randomly sample a ratio
+           from the ratio range and multiply it with the image scale. (mode 2)
+
+    - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
+      scale from a range. (mode 3)
+
+    - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
+      scale from multiple scales. (mode 4)
+
+    Args:
+        img_scale (tuple or list[tuple]): Images scales for resizing.
+        multiscale_mode (str): Either "range" or "value".
+        ratio_range (tuple[float]): (min_ratio, max_ratio)
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image.
+    """
+
+    def __init__(self,
+                 img_scale=None,
+                 multiscale_mode='range',
+                 ratio_range=None,
+                 keep_ratio=True):
+        if img_scale is None:
+            self.img_scale = None
+        else:
+            # Normalise a single tuple into a one-element list.
+            if not isinstance(img_scale, list):
+                img_scale = [img_scale]
+            self.img_scale = img_scale
+            assert mmcv.is_list_of(self.img_scale, tuple)
+
+        if ratio_range is None:
+            # mode 3 and 4: given multiple scales or a range of scales
+            assert multiscale_mode in ['value', 'range']
+        else:
+            # mode 1: given img_scale=None and a range of image ratio
+            # mode 2: given a scale and a range of image ratio
+            assert self.img_scale is None or len(self.img_scale) == 1
+
+        self.keep_ratio = keep_ratio
+        self.ratio_range = ratio_range
+        self.multiscale_mode = multiscale_mode
+
+    @staticmethod
+    def random_select(img_scales):
+        """Pick one of the candidate scales at random.
+
+        Args:
+            img_scales (list[tuple]): Images scales for selection.
+
+        Returns:
+            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
+                where ``img_scale`` is the selected image scale and
+                ``scale_idx`` is the selected index in the given candidates.
+        """
+
+        assert mmcv.is_list_of(img_scales, tuple)
+        chosen = np.random.randint(len(img_scales))
+        picked = img_scales[chosen]
+        return picked, chosen
+
+    @staticmethod
+    def random_sample(img_scales):
+        """Draw a scale between two bounds (``multiscale_mode=='range'``).
+
+        Args:
+            img_scales (list[tuple]): Images scale range for sampling.
+                There must be two tuples in img_scales, which specify the lower
+                and upper bound of image scales.
+
+        Returns:
+            (tuple, None): Returns a tuple ``(img_scale, None)``, where
+                ``img_scale`` is the sampled ``(long_edge, short_edge)`` and
+                None is just a placeholder to match :func:`random_select`.
+        """
+
+        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+        long_edges = [max(pair) for pair in img_scales]
+        short_edges = [min(pair) for pair in img_scales]
+        # The upper bounds are inclusive, hence the + 1; the long edge is
+        # drawn before the short edge.
+        long_edge = np.random.randint(
+            min(long_edges), max(long_edges) + 1)
+        short_edge = np.random.randint(
+            min(short_edges), max(short_edges) + 1)
+        sampled = (long_edge, short_edge)
+        return sampled, None
+
+    @staticmethod
+    def random_sample_ratio(img_scale, ratio_range):
+        """Scale ``img_scale`` by a ratio drawn from ``ratio_range``.
+
+        One ratio is sampled between the two bounds of ``ratio_range``
+        and both entries of ``img_scale`` are multiplied by it; the products
+        are truncated with ``int``.
+
+        Args:
+            img_scale (tuple): Images scale base to multiply with ratio.
+            ratio_range (tuple[float]): The minimum and maximum ratio to scale
+                the ``img_scale``.
+
+        Returns:
+            (tuple, None): Returns a tuple ``(scale, None)``, where
+                ``scale`` is sampled ratio multiplied with ``img_scale`` and
+                None is just a placeholder to be consistent with
+                :func:`random_select`.
+        """
+
+        assert isinstance(img_scale, tuple) and len(img_scale) == 2
+        low, high = ratio_range
+        assert low <= high
+        ratio = np.random.random_sample() * (high - low) + low
+        scaled = (int(img_scale[0] * ratio), int(img_scale[1] * ratio))
+        return scaled, None
+
+    def _random_scale(self, results):
+        """Choose the target scale from ``ratio_range``, ``img_scale`` and
+        ``multiscale_mode`` and store it in ``results``.
+
+        If ``ratio_range`` is specified, a ratio will be sampled and be
+        multiplied with ``img_scale`` (or with the image's own size).
+        If multiple scales are specified by ``img_scale``, a scale will be
+        sampled according to ``multiscale_mode``.
+        Otherwise, single scale will be used.
+
+        Args:
+            results (dict): Result dict from :obj:`dataset`.
+
+        Returns:
+            None: The keys 'scale' and 'scale_idx' are written into
+                ``results`` in place for subsequent pipelines.
+        """
+
+        if self.ratio_range is not None:
+            # Base scale: the configured one, else the image's own (w, h).
+            if self.img_scale is not None:
+                base = self.img_scale[0]
+            else:
+                height, width = results['img'].shape[:2]
+                base = (width, height)
+            picked = self.random_sample_ratio(base, self.ratio_range)
+        elif len(self.img_scale) == 1:
+            picked = (self.img_scale[0], 0)
+        elif self.multiscale_mode == 'range':
+            picked = self.random_sample(self.img_scale)
+        elif self.multiscale_mode == 'value':
+            picked = self.random_select(self.img_scale)
+        else:
+            raise NotImplementedError
+
+        scale, scale_idx = picked
+        results['scale'], results['scale_idx'] = scale, scale_idx
+
+    def _resize_img(self, results):
+        """Resize ``results['img']`` to ``results['scale']``; update metas."""
+        source = results['img']
+        target = results['scale']
+        if self.keep_ratio:
+            img, _ = mmcv.imrescale(source, target, return_scale=True)
+            # w_scale and h_scale can differ slightly, so both are recomputed
+            # from the resized shape (pending a fix in mmcv.imrescale).
+            new_h, new_w = img.shape[:2]
+            old_h, old_w = source.shape[:2]
+            w_scale, h_scale = new_w / old_w, new_h / old_h
+        else:
+            img, w_scale, h_scale = mmcv.imresize(
+                source, target, return_scale=True)
+        scale_factor = np.array(
+            [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['pad_shape'] = img.shape  # in case that there is no padding
+        results['scale_factor'] = scale_factor
+        results['keep_ratio'] = self.keep_ratio
+
+    def _resize_seg(self, results):
+        """Resize every ``seg_fields`` entry to ``results['scale']``."""
+        # Pick the routine once: it matches the one used for the image.
+        if self.keep_ratio:
+            resize = mmcv.imrescale
+        else:
+            resize = mmcv.imresize
+        for name in results.get('seg_fields', []):
+            results[name] = resize(
+                results[name], results['scale'], interpolation='nearest')
+
+    def __call__(self, results):
+        """Resize the image and the semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
+                'keep_ratio' keys are added into result dict.
+        """
+
+        # A 'scale' already present in the input takes precedence.
+        if 'scale' not in results:
+            self._random_scale(results)
+        for step in (self._resize_img, self._resize_seg):
+            step(results)
+        return results
+
+    def __repr__(self):
+        fields = ', '.join([
+            f'img_scale={self.img_scale}',
+            f'multiscale_mode={self.multiscale_mode}',
+            f'ratio_range={self.ratio_range}',
+            f'keep_ratio={self.keep_ratio}'])
+        return f'{self.__class__.__name__}({fields})'
+
+
+@PIPELINES.register_module()
+class RandomFlip(object):
+    """Flip the image & seg.
+
+    If the input dict contains the key "flip", that flag is used. Otherwise
+    a flip is drawn with probability ``prob``; with ``prob=None`` nothing is
+    flipped unless the input dict asks for it.
+
+    Args:
+        prob (float, optional): The flipping probability. Default: None.
+        direction(str, optional): The flipping direction. Options are
+            'horizontal' and 'vertical'. Default: 'horizontal'.
+    """
+
+    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
+    def __init__(self, prob=None, direction='horizontal'):
+        self.prob = prob
+        self.direction = direction
+        if prob is not None:
+            assert prob >= 0 and prob <= 1
+        assert direction in ['horizontal', 'vertical']
+
+    def __call__(self, results):
+        """Flip the image and every ``seg_fields`` entry when
+        ``results['flip']`` is (or is drawn to be) true.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Flipped results, 'flip', 'flip_direction' keys are added into
+                result dict.
+        """
+
+        if 'flip' not in results:
+            # prob=None (the default) means no random flip, not a TypeError.
+            results['flip'] = (self.prob is not None
+                               and bool(np.random.rand() < self.prob))
+        if 'flip_direction' not in results:
+            results['flip_direction'] = self.direction
+        if results['flip']:
+            # flip image
+            results['img'] = mmcv.imflip(
+                results['img'], direction=results['flip_direction'])
+            # flip segs
+            for key in results.get('seg_fields', []):
+                # use copy() to make numpy stride positive
+                results[key] = mmcv.imflip(
+                    results[key], direction=results['flip_direction']).copy()
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(prob={self.prob})'
+
+
+@PIPELINES.register_module()
+class Pad(object):
+    """Pad the image & mask.
+
+    There are two padding modes: (1) pad to a fixed size and (2) pad to the
+    minimum size that is divisible by some number.
+    Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor".
+
+    Args:
+        size (tuple, optional): Fixed padding size.
+        size_divisor (int, optional): The divisor of padded size.
+        pad_val (float, optional): Padding value. Default: 0.
+        seg_pad_val (float, optional): Padding value of segmentation map.
+            Default: 255.
+    """
+
+    def __init__(self,
+                 size=None,
+                 size_divisor=None,
+                 pad_val=0,
+                 seg_pad_val=255):
+        # Exactly one of size / size_divisor must be supplied.
+        assert size is not None or size_divisor is not None
+        assert size is None or size_divisor is None
+        self.size = size
+        self.size_divisor = size_divisor
+        self.pad_val = pad_val
+        self.seg_pad_val = seg_pad_val
+
+    def _pad_img(self, results):
+        """Pad ``results['img']`` to ``size`` or to a ``size_divisor`` multiple."""
+        if self.size is not None:
+            padded = mmcv.impad(
+                results['img'], shape=self.size, pad_val=self.pad_val)
+        elif self.size_divisor is not None:
+            padded = mmcv.impad_to_multiple(
+                results['img'], self.size_divisor, pad_val=self.pad_val)
+        results['img'] = padded
+        results['pad_shape'] = padded.shape
+        results['pad_fixed_size'] = self.size
+        results['pad_size_divisor'] = self.size_divisor
+
+    def _pad_seg(self, results):
+        """Pad each ``seg_fields`` entry to ``results['pad_shape'][:2]``."""
+        for name in results.get('seg_fields', []):
+            seg = results[name]
+            target = results['pad_shape'][:2]
+            results[name] = mmcv.impad(
+                seg, shape=target, pad_val=self.seg_pad_val)
+
+    def __call__(self, results):
+        """Pad the image first, then the segmentation maps to match it.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Updated result dict.
+        """
+
+        self._pad_img(results)
+        self._pad_seg(results)
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}(size={self.size}, '
+                f'size_divisor={self.size_divisor}, '
+                f'pad_val={self.pad_val})')
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+    """Normalize the image.
+
+    Added key is "img_norm_cfg".
+
+    Args:
+        mean (sequence): Mean values of 3 channels.
+        std (sequence): Std values of 3 channels.
+        to_rgb (bool): Whether to convert the image from BGR to RGB,
+            default is true.
+    """
+
+    def __init__(self, mean, std, to_rgb=True):
+        self.mean = np.array(mean, dtype=np.float32)
+        self.std = np.array(std, dtype=np.float32)
+        self.to_rgb = to_rgb
+
+    def __call__(self, results):
+        """Normalize ``results['img']`` and record the settings used.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Normalized results, 'img_norm_cfg' key is added into
+                result dict.
+        """
+
+        results['img'] = mmcv.imnormalize(
+            results['img'], self.mean, self.std, self.to_rgb)
+        results['img_norm_cfg'] = {
+            'mean': self.mean, 'std': self.std, 'to_rgb': self.to_rgb}
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        return (f'{cls_name}(mean={self.mean}, '
+                f'std={self.std}, '
+                f'to_rgb={self.to_rgb})')
+
+
+@PIPELINES.register_module()
+class Rerange(object):
+    """Linearly rescale the image pixel values onto a target interval.
+
+    Args:
+        min_value (float or int): Minimum value of the reranged image.
+            Default: 0.
+        max_value (float or int): Maximum value of the reranged image.
+            Default: 255.
+    """
+
+    def __init__(self, min_value=0, max_value=255):
+        assert isinstance(min_value, (float, int))
+        assert isinstance(max_value, (float, int))
+        assert min_value < max_value
+        self.min_value = min_value
+        self.max_value = max_value
+
+    def __call__(self, results):
+        """Rerange ``results['img']``; its min must be below its max.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+        Returns:
+            dict: Reranged results.
+        """
+
+        img = results['img']
+        lowest = np.min(img)
+        highest = np.max(img)
+
+        assert lowest < highest
+        # First map the pixel values onto [0, 1] ...
+        unit = (img - lowest) / (highest - lowest)
+        # ... then stretch them onto [min_value, max_value].
+        target_span = self.max_value - self.min_value
+        results['img'] = unit * target_span + self.min_value
+
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        bounds = f'min_value={self.min_value}, max_value={self.max_value}'
+        return f'{cls_name}({bounds})'
+
+
+@PIPELINES.register_module()
+class CLAHE(object):
+    """Apply Contrast Limited Adaptive Histogram Equalization to the image.
+
+    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+    Graphics Gems, 1994:474-485.` for more information.
+
+    Args:
+        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+            Input image will be divided into equally sized rectangular tiles.
+            It defines the number of tiles in row and column. Default: (8, 8).
+    """
+
+    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
+        assert isinstance(clip_limit, (float, int))
+        # the grid must be a tuple of exactly two integers
+        assert is_tuple_of(tile_grid_size, int)
+        assert len(tile_grid_size) == 2
+        self.clip_limit, self.tile_grid_size = clip_limit, tile_grid_size
+
+    def __call__(self, results):
+        """Equalize every channel of ``results['img']`` in place.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Processed results.
+        """
+
+        for ch in range(results['img'].shape[2]):
+            plane = np.array(results['img'][:, :, ch], dtype=np.uint8)
+            results['img'][:, :, ch] = mmcv.clahe(plane, self.clip_limit,
+                                                  self.tile_grid_size)
+
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        args = (f'(clip_limit={self.clip_limit}, '
+                f'tile_grid_size={self.tile_grid_size})')
+        return cls_name + args
+
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+    """Random crop the image & seg.
+
+    Args:
+        crop_size (tuple): Expected size after cropping, (h, w).
+        cat_max_ratio (float): Max ratio a single category may occupy.
+        ignore_index (int): Label excluded from the ratio check.
+    """
+
+    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+        # both target dimensions must be strictly positive
+        assert crop_size[0] > 0 and crop_size[1] > 0
+        self.crop_size = crop_size
+        self.cat_max_ratio, self.ignore_index = cat_max_ratio, ignore_index
+
+    def get_crop_bbox(self, img):
+        """Draw a random (y1, y2, x1, x2) window of ``crop_size``."""
+        crop_h, crop_w = self.crop_size[0], self.crop_size[1]
+        # slack left along each axis; zero when the image is not larger
+        slack_h = max(img.shape[0] - crop_h, 0)
+        slack_w = max(img.shape[1] - crop_w, 0)
+        top = np.random.randint(0, slack_h + 1)
+        left = np.random.randint(0, slack_w + 1)
+
+        return top, top + crop_h, left, left + crop_w
+
+    def crop(self, img, crop_bbox):
+        """Crop from ``img``"""
+        y1, y2, x1, x2 = crop_bbox
+        window = img[y1:y2, x1:x2, ...]
+        return window
+
+    def __call__(self, results):
+        """Call function to randomly crop images, semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Randomly cropped results, 'img_shape' key in result dict is
+                updated according to crop size.
+        """
+
+        img = results['img']
+        bbox = self.get_crop_bbox(img)
+        if self.cat_max_ratio < 1.:
+            # Try at most 10 windows; the last draw is used if none passes
+            for _ in range(10):
+                seg_patch = self.crop(results['gt_semantic_seg'], bbox)
+                labels, counts = np.unique(seg_patch, return_counts=True)
+                counts = counts[labels != self.ignore_index]
+                if len(counts) > 1 and np.max(counts) / np.sum(
+                        counts) < self.cat_max_ratio:
+                    break
+                bbox = self.get_crop_bbox(img)
+
+        # crop the image
+        cropped = self.crop(img, bbox)
+        results['img'] = cropped
+        # keep the recorded shape in sync with the cropped image
+        results['img_shape'] = cropped.shape
+
+        # crop semantic seg
+        for key in results.get('seg_fields', []):
+            results[key] = self.crop(results[key], bbox)
+
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(crop_size={self.crop_size})'
+
+
+@PIPELINES.register_module()
+class RandomRotate(object):
+    """Randomly rotate the image and its segmentation maps.
+
+    Args:
+        prob (float): The rotation probability.
+        degree (float, tuple[float]): Range of degrees to select from. If
+            degree is a number instead of tuple like (min, max),
+            the range of degree will be (``-degree``, ``+degree``)
+        pad_val (float, optional): Padding value of image. Default: 0.
+        seg_pad_val (float, optional): Padding value of segmentation map.
+            Default: 255.
+        center (tuple[float], optional): Center point (w, h) of the rotation in
+            the source image. If not specified, the center of the image will be
+            used. Default: None.
+        auto_bound (bool): Whether to adjust the image size to cover the whole
+            rotated image. Default: False
+    """
+
+    def __init__(self,
+                 prob,
+                 degree,
+                 pad_val=0,
+                 seg_pad_val=255,
+                 center=None,
+                 auto_bound=False):
+        self.prob = prob
+        assert 0 <= prob <= 1
+        if not isinstance(degree, (float, int)):
+            self.degree = degree
+        else:
+            # a bare number means a symmetric range around zero
+            assert degree > 0, f'degree {degree} should be positive'
+            self.degree = (-degree, degree)
+        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
+                                      f'tuple of (min, max)'
+        # stored as 'pal_val' (sic) -- __call__ and __repr__ read that name
+        self.pal_val, self.seg_pad_val = pad_val, seg_pad_val
+        self.center, self.auto_bound = center, auto_bound
+
+    def __call__(self, results):
+        """Call function to rotate image, semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Rotated results.
+        """
+
+        do_rotate = np.random.rand() < self.prob
+        # the angle is drawn even when no rotation ends up being applied
+        angle = np.random.uniform(min(self.degree), max(self.degree))
+        if not do_rotate:
+            return results
+
+        # arguments shared by the image and every segmentation map
+        common = dict(
+            angle=angle, center=self.center, auto_bound=self.auto_bound)
+        # rotate image
+        results['img'] = mmcv.imrotate(
+            results['img'], border_value=self.pal_val, **common)
+
+        # rotate segs
+        for key in results.get('seg_fields', []):
+            results[key] = mmcv.imrotate(
+                results[key],
+                border_value=self.seg_pad_val,
+                interpolation='nearest',
+                **common)
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        args = (f'(prob={self.prob}, '
+                f'degree={self.degree}, '
+                f'pad_val={self.pal_val}, '
+                f'seg_pad_val={self.seg_pad_val}, '
+                f'center={self.center}, '
+                f'auto_bound={self.auto_bound})')
+        return cls_name + args
+
+
+@PIPELINES.register_module()
+class RGB2Gray(object):
+    """Convert RGB image to grayscale image.
+
+    This transform calculates the weighted mean of input image channels with
+    ``weights`` and then expands the channels to ``out_channels``. When
+    ``out_channels`` is None, the number of output channels is the same as
+    input channels.
+
+    Args:
+        out_channels (int): Expected number of output channels after
+            transforming. Default: None.
+        weights (tuple[float]): The weights to calculate the weighted mean.
+            Default: (0.299, 0.587, 0.114).
+    """
+
+    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
+        assert out_channels is None or out_channels > 0
+        assert isinstance(weights, tuple)
+        # every weight must be a plain number
+        assert all(isinstance(w, (float, int)) for w in weights)
+        self.out_channels = out_channels
+        self.weights = weights
+
+    def __call__(self, results):
+        """Call function to convert RGB image to grayscale image.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with grayscale image.
+        """
+        img = results['img']
+        assert len(img.shape) == 3
+        assert img.shape[2] == len(self.weights)
+        weights = np.array(self.weights).reshape((1, 1, -1))
+        gray = (img * weights).sum(2, keepdims=True)
+        # with no explicit channel count, keep as many channels as the input
+        n_out = (weights.shape[2]
+                 if self.out_channels is None else self.out_channels)
+        gray = gray.repeat(n_out, axis=2)
+
+        results['img'] = gray
+        results['img_shape'] = gray.shape
+
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        args = (f'(out_channels={self.out_channels}, '
+                f'weights={self.weights})')
+        return cls_name + args
+
+
+@PIPELINES.register_module()
+class AdjustGamma(object):
+    """Process the image with gamma correction through a lookup table.
+
+    Args:
+        gamma (float or int): Gamma value used in gamma correction.
+            Default: 1.0.
+    """
+
+    def __init__(self, gamma=1.0):
+        assert isinstance(gamma, (float, int)) and gamma > 0
+        self.gamma = gamma
+        exponent = 1.0 / gamma
+        # 256-entry lookup table: corrected output for every uint8 input
+        lut = [(level / 255.0)**exponent * 255 for level in np.arange(256)]
+        self.table = np.array(lut).astype('uint8')
+
+    def __call__(self, results):
+        """Call function to process the image with gamma correction.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Processed results.
+        """
+
+        img_u8 = np.array(results['img'], dtype=np.uint8)
+        results['img'] = mmcv.lut_transform(img_u8, self.table)
+
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(gamma={self.gamma})'
+
+
+@PIPELINES.register_module()
+class SegRescale(object):
+    """Rescale semantic segmentation maps.
+
+    Args:
+        scale_factor (float): The scale factor of the final output.
+    """
+
+    def __init__(self, scale_factor=1):
+        self.scale_factor = scale_factor
+
+    def __call__(self, results):
+        """Rescale every map listed in ``results['seg_fields']``.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with semantic segmentation map scaled.
+        """
+        if self.scale_factor != 1:
+            for key in results.get('seg_fields', []):
+                results[key] = mmcv.imrescale(
+                    results[key], self.scale_factor, interpolation='nearest')
+        return results
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(scale_factor={self.scale_factor})'
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion(object):
+    """Apply photometric distortion to image sequentially, every transformation
+    is applied with a probability of 0.5. The position of random contrast is in
+    second or second to last.
+
+    1. random brightness
+    2. random contrast (when the drawn mode is 1)
+    3. random saturation: BGR -> HSV, scale channel 1, HSV -> BGR
+    4. random hue: BGR -> HSV, shift channel 0 modulo 180, HSV -> BGR
+    5. random contrast (when the drawn mode is 0)
+
+    Saturation and hue each do their own colour-space round trip.
+
+    Args:
+        brightness_delta (int): delta of brightness.
+        contrast_range (tuple): range of contrast.
+        saturation_range (tuple): range of saturation.
+        hue_delta (int): delta of hue.
+    """
+
+    def __init__(self,
+                 brightness_delta=32,
+                 contrast_range=(0.5, 1.5),
+                 saturation_range=(0.5, 1.5),
+                 hue_delta=18):
+        self.brightness_delta, self.hue_delta = brightness_delta, hue_delta
+        # each range is a (lower, upper) pair
+        self.contrast_lower, self.contrast_upper = contrast_range
+        self.saturation_lower, self.saturation_upper = saturation_range
+
+    def convert(self, img, alpha=1, beta=0):
+        """Multiply by alpha, add beta, then clip to [0, 255] as uint8."""
+        scaled = img.astype(np.float32) * alpha + beta
+        clipped = np.clip(scaled, 0, 255)
+        return clipped.astype(np.uint8)
+
+    def brightness(self, img):
+        """Brightness distortion."""
+        if not random.randint(2):
+            return img
+        # additive shift drawn between -brightness_delta and brightness_delta
+        delta = random.uniform(-self.brightness_delta,
+                               self.brightness_delta)
+        return self.convert(img, beta=delta)
+
+    def contrast(self, img):
+        """Contrast distortion."""
+        if not random.randint(2):
+            return img
+        # multiplicative gain drawn between the two contrast bounds
+        gain = random.uniform(self.contrast_lower, self.contrast_upper)
+        return self.convert(img, alpha=gain)
+
+    def saturation(self, img):
+        """Saturation distortion."""
+        if not random.randint(2):
+            return img
+        hsv = mmcv.bgr2hsv(img)
+        # only channel 1 of the HSV image is rescaled
+        gain = random.uniform(self.saturation_lower,
+                              self.saturation_upper)
+        hsv[:, :, 1] = self.convert(hsv[:, :, 1], alpha=gain)
+        return mmcv.hsv2bgr(hsv)
+
+    def hue(self, img):
+        """Hue distortion."""
+        if not random.randint(2):
+            return img
+        hsv = mmcv.bgr2hsv(img)
+        # shift channel 0 of the HSV image and wrap it modulo 180
+        shift = random.randint(-self.hue_delta, self.hue_delta)
+        hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + shift) % 180
+        return mmcv.hsv2bgr(hsv)
+
+    def __call__(self, results):
+        """Call function to perform photometric distortion on images.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with images distorted.
+        """
+
+        img = results['img']
+        # random brightness
+        img = self.brightness(img)
+
+        # mode == 1 --> random contrast runs before saturation and hue
+        # mode == 0 --> random contrast runs after saturation and hue
+        mode = random.randint(2)
+        contrast_first = mode == 1
+        if contrast_first:
+            img = self.contrast(img)
+
+        # random saturation
+        img = self.saturation(img)
+        # random hue
+        img = self.hue(img)
+
+        # random contrast
+        if not contrast_first:
+            img = self.contrast(img)
+
+        results['img'] = img
+        return results
+
+    def __repr__(self):
+        cls_name = self.__class__.__name__
+        args = (f'(brightness_delta={self.brightness_delta}, '
+                f'contrast_range=({self.contrast_lower}, '
+                f'{self.contrast_upper}), '
+                f'saturation_range=({self.saturation_lower}, '
+                f'{self.saturation_upper}), '
+                f'hue_delta={self.hue_delta})')
+        return cls_name + args
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/stare.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/stare.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class STAREDataset(CustomDataset):
+    """STARE dataset.
+
+    In segmentation map annotation for STARE, 0 stands for background, which is
+    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '.ah.png'.
+    """
+
+    CLASSES = ('background', 'vessel')
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    def __init__(self, **kwargs):
+        # suffixes and reduce_zero_label are pinned; the rest is forwarded
+        fixed = dict(
+            img_suffix='.png', seg_map_suffix='.ah.png',
+            reduce_zero_label=False)
+        super().__init__(**fixed, **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/voc.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/datasets/voc.py
@@ -0,0 +1,29 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalVOCDataset(CustomDataset):
+    """Pascal VOC dataset (``.jpg`` images, ``.png`` segmentation maps).
+
+    Args:
+        split (str): Split txt file for Pascal VOC.
+    """
+
+    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat',
+               'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
+               'diningtable', 'dog', 'horse', 'motorbike', 'person',
+               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
+
+    PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
+               [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
+               [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
+               [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
+               [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+    def __init__(self, split, **kwargs):
+        super().__init__(
+            split=split, img_suffix='.jpg', seg_map_suffix='.png', **kwargs)
+        assert osp.exists(self.img_dir) and self.split is not None
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf93f8bec9cf0cef0a3bd76ca3ca92eb188f535
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/__init__.py
@@ -0,0 +1,12 @@
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
+ build_head, build_loss, build_segmentor)
+from .decode_heads import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .segmentors import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
+ 'build_head', 'build_loss', 'build_segmentor'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1116c00a17c8bd9ed7f18743baee22b3b7d3f8d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/__init__.py
@@ -0,0 +1,16 @@
+from .cgnet import CGNet
+# from .fast_scnn import FastSCNN
+from .hrnet import HRNet
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1c, ResNetV1d
+from .resnext import ResNeXt
+from .unet import UNet
+from .vit import VisionTransformer
+
+__all__ = [
+ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
+ 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
+ 'VisionTransformer'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/cgnet.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c235e2e7fcef21e933ecb3ff88a37fa953abe6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/cgnet.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.mmpkg.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
+ constant_init, kaiming_init)
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.mmpkg.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class GlobalContextExtractor(nn.Module):
+    """Global Context Extractor for CGNet.
+
+    Globally pools the input, feeds it through a two-layer bottleneck with a
+    sigmoid, and rescales every input channel by the resulting weight.
+
+    Args:
+        channel (int): Number of input feature channels.
+        reduction (int): Reductions for global context extractor. Default: 16.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self, channel, reduction=16, with_cp=False):
+        super().__init__()
+        assert reduction >= 1 and channel >= reduction
+        self.channel, self.reduction = channel, reduction
+        self.with_cp = with_cp
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        hidden = channel // reduction
+        self.fc = nn.Sequential(
+            nn.Linear(channel, hidden), nn.ReLU(inplace=True),
+            nn.Linear(hidden, channel), nn.Sigmoid())
+
+    def forward(self, x):
+        """Return ``x`` multiplied by its learned per-channel weights."""
+
+        def _gate(feat):
+            # per-channel weights computed from the globally pooled feature
+            batch, chans = feat.size()[:2]
+            scale = self.avg_pool(feat).view(batch, chans)
+            scale = self.fc(scale).view(batch, chans, 1, 1)
+            return feat * scale
+
+        # use checkpointing only when it is enabled and x needs gradients
+        if self.with_cp and x.requires_grad:
+            return cp.checkpoint(_gate, x)
+        return _gate(x)
+
+
+class ContextGuidedBlock(nn.Module):
+    """Context Guided Block for CGNet.
+
+    This class consists of four components: local feature extractor,
+    surrounding feature extractor, joint feature extractor and global
+    context extractor.
+
+    Args:
+        in_channels (int): Number of input feature channels.
+        out_channels (int): Number of output feature channels.
+        dilation (int): Dilation rate for surrounding context extractor.
+            Default: 2.
+        reduction (int): Reduction for global context extractor. Default: 16.
+        skip_connect (bool): Add input to output or not. Default: True.
+        downsample (bool): Downsample the input to 1/2 or not. Default: False.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN', requires_grad=True).
+        act_cfg (dict): Config dict for activation layer. It is read only;
+            a copy is extended for PReLU. Default: dict(type='PReLU').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 dilation=2,
+                 reduction=16,
+                 skip_connect=True,
+                 downsample=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 act_cfg=dict(type='PReLU'),
+                 with_cp=False):
+        super(ContextGuidedBlock, self).__init__()
+        self.with_cp = with_cp
+        self.downsample = downsample
+        channels = out_channels if downsample else out_channels // 2
+        # PReLU needs a channel count; add it to a copy, not the shared dict
+        if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
+            act_cfg = dict(act_cfg, num_parameters=channels)
+        kernel_size = 3 if downsample else 1
+        stride = 2 if downsample else 1
+        padding = (kernel_size - 1) // 2
+
+        self.conv1x1 = ConvModule(
+            in_channels,
+            channels,
+            kernel_size,
+            stride,
+            padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        self.f_loc = build_conv_layer(
+            conv_cfg,
+            channels,
+            channels,
+            kernel_size=3,
+            padding=1,
+            groups=channels,
+            bias=False)
+        self.f_sur = build_conv_layer(
+            conv_cfg,
+            channels,
+            channels,
+            kernel_size=3,
+            padding=dilation,
+            groups=channels,
+            dilation=dilation,
+            bias=False)
+
+        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
+        self.activate = nn.PReLU(2 * channels)
+
+        if downsample:
+            self.bottleneck = build_conv_layer(
+                conv_cfg,
+                2 * channels,
+                out_channels,
+                kernel_size=1,
+                bias=False)
+
+        self.skip_connect = skip_connect and not downsample
+        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            out = self.conv1x1(x)
+            loc = self.f_loc(out)
+            sur = self.f_sur(out)
+
+            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
+            joi_feat = self.bn(joi_feat)
+            joi_feat = self.activate(joi_feat)
+            if self.downsample:
+                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
+            # f_glo is employed to refine the joint feature
+            out = self.f_glo(joi_feat)
+
+            if self.skip_connect:
+                return x + out
+            else:
+                return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
+
+class InputInjection(nn.Module):
+    """Chain of stride-2 average pools used to downsample the CGNet input."""
+
+    def __init__(self, num_downsampling):
+        super().__init__()
+        self.pool = nn.ModuleList([
+            nn.AvgPool2d(3, stride=2, padding=1)
+            for _ in range(num_downsampling)])
+
+    def forward(self, x):
+        for stage in self.pool:
+            x = stage(x)
+        return x
+
+
+@BACKBONES.register_module()
+class CGNet(nn.Module):
+    """CGNet backbone.
+
+    A Light-weight Context Guided Network for Semantic Segmentation
+    arXiv: https://arxiv.org/abs/1811.08201
+
+    Args:
+        in_channels (int): Number of input image channels. Normally 3.
+        num_channels (tuple[int]): Numbers of feature channels at each stages.
+            Default: (32, 64, 128).
+        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
+            Default: (3, 21).
+        dilations (tuple[int]): Dilation rate for surrounding context
+            extractors at stage 1 and stage 2. Default: (2, 4).
+        reductions (tuple[int]): Reductions for global context extractors at
+            stage 1 and stage 2. Default: (8, 16).
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN', requires_grad=True).
+        act_cfg (dict): Config dict for activation layer; never modified in
+            place. Default: dict(type='PReLU').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 num_channels=(32, 64, 128),
+                 num_blocks=(3, 21),
+                 dilations=(2, 4),
+                 reductions=(8, 16),
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 act_cfg=dict(type='PReLU'),
+                 norm_eval=False,
+                 with_cp=False):
+        super(CGNet, self).__init__()
+        self.in_channels = in_channels
+        self.num_channels = num_channels
+        assert isinstance(self.num_channels, tuple) and len(
+            self.num_channels) == 3
+        self.num_blocks = num_blocks
+        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
+        self.dilations = dilations
+        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
+        self.reductions = reductions
+        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
+            # PReLU needs a channel count; add it to a copy, not the shared dict
+            act_cfg = dict(act_cfg, num_parameters=num_channels[0])
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        cur_channels = in_channels
+        self.stem = nn.ModuleList()
+        for i in range(3):
+            self.stem.append(
+                ConvModule(
+                    cur_channels,
+                    num_channels[0],
+                    3,
+                    2 if i == 0 else 1,
+                    padding=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+            cur_channels = num_channels[0]
+
+        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
+        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4
+
+        cur_channels += in_channels
+        self.norm_prelu_0 = nn.Sequential(
+            build_norm_layer(norm_cfg, cur_channels)[1],
+            nn.PReLU(cur_channels))
+
+        # stage 1
+        self.level1 = nn.ModuleList()
+        for i in range(num_blocks[0]):
+            self.level1.append(
+                ContextGuidedBlock(
+                    cur_channels if i == 0 else num_channels[1],
+                    num_channels[1],
+                    dilations[0],
+                    reductions[0],
+                    downsample=(i == 0),
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    with_cp=with_cp))  # CG block
+
+        cur_channels = 2 * num_channels[1] + in_channels
+        self.norm_prelu_1 = nn.Sequential(
+            build_norm_layer(norm_cfg, cur_channels)[1],
+            nn.PReLU(cur_channels))
+
+        # stage 2
+        self.level2 = nn.ModuleList()
+        for i in range(num_blocks[1]):
+            self.level2.append(
+                ContextGuidedBlock(
+                    cur_channels if i == 0 else num_channels[2],
+                    num_channels[2],
+                    dilations[1],
+                    reductions[1],
+                    downsample=(i == 0),
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    with_cp=with_cp))  # CG block
+
+        cur_channels = 2 * num_channels[2]
+        self.norm_prelu_2 = nn.Sequential(
+            build_norm_layer(norm_cfg, cur_channels)[1],
+            nn.PReLU(cur_channels))
+
+    def forward(self, x):
+        output = []
+
+        # stage 0
+        inp_2x = self.inject_2x(x)
+        inp_4x = self.inject_4x(x)
+        for layer in self.stem:
+            x = layer(x)
+        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
+        output.append(x)
+
+        # stage 1
+        for i, layer in enumerate(self.level1):
+            x = layer(x)
+            if i == 0:
+                down1 = x
+        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
+        output.append(x)
+
+        # stage 2
+        for i, layer in enumerate(self.level2):
+            x = layer(x)
+            if i == 0:
+                down2 = x
+        x = self.norm_prelu_2(torch.cat([down2, x], 1))
+        output.append(x)
+
+        return output
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, (nn.Conv2d, nn.Linear)):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+                elif isinstance(m, nn.PReLU):
+                    constant_init(m, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def train(self, mode=True):
+        """Switch to training mode; when ``norm_eval`` is set, the batch-norm
+        layers are put back into eval mode afterwards."""
+        super(CGNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval have effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/fast_scnn.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..417114417ebc830ea11ae7216aa12d8f7a79e5cb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/fast_scnn.py
@@ -0,0 +1,375 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
+ kaiming_init)
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from annotator.mmpkg.mmseg.models.decode_heads.psp_head import PPM
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import BACKBONES
+from ..utils.inverted_residual import InvertedResidual
+
+
+class LearningToDownsample(nn.Module):
+ """Learning to downsample module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ dw_channels (tuple[int]): Output channels of the initial conv
+ (``conv``) and of the first depthwise separable conv (``dsconv1``).
+ out_channels (int): Number of output channels of the whole
+ 'learning to downsample' module.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers; passed to the initial
+ conv only. Default: dict(type='ReLU')
+ """
+
+ def __init__(self,
+ in_channels,
+ dw_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU')):
+ super(LearningToDownsample, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ dw_channels1 = dw_channels[0]
+ dw_channels2 = dw_channels[1]
+ # All three layers below use stride=2; only ``conv`` gets conv_cfg/act_cfg.
+ self.conv = ConvModule(
+ in_channels,
+ dw_channels1,
+ 3,
+ stride=2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.dsconv1 = DepthwiseSeparableConvModule(
+ dw_channels1,
+ dw_channels2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+ self.dsconv2 = DepthwiseSeparableConvModule(
+ dw_channels2,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg)
+
+ def forward(self, x):  # applies conv, dsconv1 and dsconv2 in that order
+ x = self.conv(x)
+ x = self.dsconv1(x)
+ x = self.dsconv2(x)
+ return x
+
+
+class GlobalFeatureExtractor(nn.Module):
+ """Global feature extractor module.
+
+ Args:
+ in_channels (int): Number of input channels of the GFE module.
+ Default: 64
+ block_channels (tuple[int]): Tuple of ints. Each int specifies the
+ number of output channels of each Inverted Residual module.
+ Default: (64, 96, 128). Must have exactly 3 entries.
+ out_channels (int): Number of output channels of the GFE module.
+ Default: 128
+ expand_ratio (int): Adjusts number of channels of the hidden layer
+ in InvertedResidual by this amount.
+ Default: 6
+ num_blocks (tuple[int]): Tuple of ints. Each int specifies the
+ number of times each Inverted Residual module is repeated.
+ The repeated Inverted Residual modules are called a 'group'.
+ Default: (3, 3, 3). Must have exactly 3 entries.
+ strides (tuple[int]): Tuple of ints. Each int specifies
+ the downsampling factor of each 'group'.
+ Default: (2, 2, 1)
+ pool_scales (tuple[int]): Tuple of ints. Each int specifies
+ the parameter required in 'global average pooling' within PPM.
+ Default: (1, 2, 3, 6)
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument forwarded to PPM.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels=64,
+ block_channels=(64, 96, 128),
+ out_channels=128,
+ expand_ratio=6,
+ num_blocks=(3, 3, 3),
+ strides=(2, 2, 1),
+ pool_scales=(1, 2, 3, 6),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(GlobalFeatureExtractor, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ assert len(block_channels) == len(num_blocks) == 3  # three groups are built below
+ self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
+ num_blocks[0], strides[0],
+ expand_ratio)
+ self.bottleneck2 = self._make_layer(block_channels[0],
+ block_channels[1], num_blocks[1],
+ strides[1], expand_ratio)
+ self.bottleneck3 = self._make_layer(block_channels[1],
+ block_channels[2], num_blocks[2],
+ strides[2], expand_ratio)
+ self.ppm = PPM(
+ pool_scales,
+ block_channels[2],
+ block_channels[2] // 4,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=align_corners)
+ self.out = ConvModule(
+ block_channels[2] * 2,  # NOTE(review): presumably input + PPM outputs for 4 pool scales; confirm against PPM
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def _make_layer(self,
+ in_channels,
+ out_channels,
+ blocks,
+ stride=1,
+ expand_ratio=6):
+ layers = [  # the first block maps in_channels -> out_channels with ``stride``
+ InvertedResidual(
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ norm_cfg=self.norm_cfg)
+ ]
+ for i in range(1, blocks):  # remaining blocks keep out_channels and use stride 1
+ layers.append(
+ InvertedResidual(
+ out_channels,
+ out_channels,
+ 1,
+ expand_ratio,
+ norm_cfg=self.norm_cfg))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.bottleneck1(x)
+ x = self.bottleneck2(x)
+ x = self.bottleneck3(x)
+ x = torch.cat([x, *self.ppm(x)], dim=1)  # append the PPM outputs along dim 1
+ x = self.out(x)
+ return x
+
+
+class FeatureFusionModule(nn.Module):
+ """Feature fusion module.
+
+ Args:
+ higher_in_channels (int): Number of input channels of the
+ higher-resolution branch.
+ lower_in_channels (int): Number of input channels of the
+ lower-resolution branch.
+ out_channels (int): Number of output channels.
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers; used by ``dwconv``
+ only. Default: dict(type='ReLU')
+ align_corners (bool): align_corners argument passed to ``resize``.
+ Default: False
+ """
+
+ def __init__(self,
+ higher_in_channels,
+ lower_in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+ super(FeatureFusionModule, self).__init__()
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.dwconv = ConvModule(  # built with kernel size 1 despite the "dwconv" name
+ lower_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.conv_lower_res = ConvModule(
+ out_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)  # no activation here; ReLU is applied after the sum
+ self.conv_higher_res = ConvModule(
+ higher_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)  # no activation here; ReLU is applied after the sum
+ self.relu = nn.ReLU(True)
+
+ def forward(self, higher_res_feature, lower_res_feature):
+ lower_res_feature = resize(  # match the higher-res spatial size (dims 2 onward)
+ lower_res_feature,
+ size=higher_res_feature.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ lower_res_feature = self.dwconv(lower_res_feature)
+ lower_res_feature = self.conv_lower_res(lower_res_feature)
+
+ higher_res_feature = self.conv_higher_res(higher_res_feature)
+ out = higher_res_feature + lower_res_feature  # element-wise fusion of both branches
+ return self.relu(out)
+
+
+@BACKBONES.register_module()
+class FastSCNN(nn.Module):
+ """Fast-SCNN Backbone.
+
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+ downsample_dw_channels (tuple[int]): Number of output channels after
+ the first conv layer & the second conv layer in
+ Learning-To-Downsample (LTD) module.
+ Default: (32, 48).
+ global_in_channels (int): Number of input channels of
+ Global Feature Extractor(GFE).
+ Equal to number of output channels of LTD.
+ Default: 64.
+ global_block_channels (tuple[int]): Tuple of integers that describe
+ the output channels for each of the MobileNet-v2 bottleneck
+ residual blocks in GFE.
+ Default: (64, 96, 128).
+ global_block_strides (tuple[int]): Tuple of integers
+ that describe the strides (downsampling factors) for each of the
+ MobileNet-v2 bottleneck residual blocks in GFE.
+ Default: (2, 2, 1).
+ global_out_channels (int): Number of output channels of GFE.
+ Default: 128.
+ higher_in_channels (int): Number of input channels of the higher
+ resolution branch in FFM.
+ Equal to global_in_channels.
+ Default: 64.
+ lower_in_channels (int): Number of input channels of the lower
+ resolution branch in FFM.
+ Equal to global_out_channels.
+ Default: 128.
+ fusion_out_channels (int): Number of output channels of FFM.
+ Default: 128.
+ out_indices (tuple): Tuple of indices of list
+ [higher_res_features, lower_res_features, fusion_output].
+ Often set to (0,1,2) to enable aux. heads.
+ Default: (0, 1, 2).
+ conv_cfg (dict | None): Config of conv layers. Default: None
+ norm_cfg (dict | None): Config of norm layers. Default:
+ dict(type='BN')
+ act_cfg (dict): Config of activation layers. Default:
+ dict(type='ReLU')
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False
+ """
+
+ def __init__(self,
+ in_channels=3,
+ downsample_dw_channels=(32, 48),
+ global_in_channels=64,
+ global_block_channels=(64, 96, 128),
+ global_block_strides=(2, 2, 1),
+ global_out_channels=128,
+ higher_in_channels=64,
+ lower_in_channels=128,
+ fusion_out_channels=128,
+ out_indices=(0, 1, 2),
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU'),
+ align_corners=False):
+
+ super(FastSCNN, self).__init__()
+ if global_in_channels != higher_in_channels:
+ raise AssertionError('Global Input Channels must be the same \
+ with Higher Input Channels!')
+ elif global_out_channels != lower_in_channels:
+ raise AssertionError('Global Output Channels must be the same \
+ with Lower Input Channels!')
+
+ self.in_channels = in_channels
+ self.downsample_dw_channels1 = downsample_dw_channels[0]
+ self.downsample_dw_channels2 = downsample_dw_channels[1]
+ self.global_in_channels = global_in_channels
+ self.global_block_channels = global_block_channels
+ self.global_block_strides = global_block_strides
+ self.global_out_channels = global_out_channels
+ self.higher_in_channels = higher_in_channels
+ self.lower_in_channels = lower_in_channels
+ self.fusion_out_channels = fusion_out_channels
+ self.out_indices = out_indices
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+ self.learning_to_downsample = LearningToDownsample(
+ in_channels,
+ downsample_dw_channels,
+ global_in_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.global_feature_extractor = GlobalFeatureExtractor(
+ global_in_channels,
+ global_block_channels,
+ global_out_channels,
+ strides=self.global_block_strides,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.feature_fusion = FeatureFusionModule(
+ higher_in_channels,
+ lower_in_channels,
+ fusion_out_channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+
+ def init_weights(self, pretrained=None):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ def forward(self, x):
+ higher_res_features = self.learning_to_downsample(x)
+ lower_res_features = self.global_feature_extractor(higher_res_features)
+ fusion_output = self.feature_fusion(higher_res_features,
+ lower_res_features)
+
+ outs = [higher_res_features, lower_res_features, fusion_output]
+ outs = [outs[i] for i in self.out_indices]
+ return tuple(outs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/hrnet.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d77fd6eadeec25a6b84619f0d7efa7c577b0464
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/hrnet.py
@@ -0,0 +1,555 @@
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.mmpkg.mmseg.ops import Upsample, resize
+from annotator.mmpkg.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(nn.Module):
+ """High-Resolution Module for HRNet.
+
+ Branch ``i`` stacks ``num_blocks[i]`` blocks of type ``blocks``. The
+ cross-branch fusion/exchange is also done in this module.
+ """
+
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(HRModule, self).__init__()
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+ # ``in_channels`` is kept by reference; _make_one_branch updates it in place.
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ """Raise ValueError unless each per-branch list has num_branches items."""
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
+ f'{len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
+ f'{len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
+ f'{len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ """Build one branch as an nn.Sequential of ``block`` instances."""
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):  # remaining blocks take the updated channel count
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ """Build one branch per index and wrap them in an nn.ModuleList."""
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ """Build the fuse layers; returns None when there is a single branch."""
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:  # 1x1 conv + norm, then upsample by 2**(j - i)
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ # we set align_corners=False for HRNet
+ Upsample(
+ scale_factor=2**(j - i),
+ mode='bilinear',
+ align_corners=False)))
+ elif j == i:  # same branch: no layer, forward() adds x[j] directly
+ fuse_layer.append(None)
+ else:  # j < i: chain of (i - j) stride-2 3x3 convs
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:  # last conv switches to in_channels[i], no ReLU
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:  # intermediate convs keep in_channels[j] and end with ReLU
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function: run each branch, then fuse across branches."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])  # overwrites entries of the caller's list
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ elif j > i:
+ y = y + resize(
+ self.fuse_layers[i][j](x[j]),
+ size=x[i].shape[2:],
+ mode='bilinear',
+ align_corners=False)
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+@BACKBONES.register_module()
+class HRNet(nn.Module):
+ """HRNet backbone.
+
+ High-Resolution Representations for Labeling Pixels and Regions
+ arXiv: https://arxiv.org/abs/1904.04514
+
+ Args:
+ extra (dict): Per-stage configuration; keys 'stage1'..'stage4' are read.
+ in_channels (int): Number of input image channels. Normally 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from annotator.mmpkg.mmseg.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ with_cp=False,
+ zero_init_residual=False):
+ super(HRNet, self).__init__()
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # stem net: conv1/norm1 and conv2/norm2, both convs 3x3 with stride 2
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # stage 1: a single branch built with _make_layer from 64 input channels
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+ # stage 2: configured widths are scaled by block.expansion first
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3: same recipe, fed by the channel list returned for stage 2
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4: same recipe, fed by the channel list returned for stage 3
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ """Make transition layer: one entry (module or None) per current branch."""
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:  # widths differ: 3x3 stride-1 conv
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:  # same width: None marks "pass the input through"
+ transition_layers.append(None)
+ else:  # branch absent from the previous stage: stride-2 conv chain
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ """Stack ``blocks`` blocks; only the first gets ``stride``/``downsample``."""
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ """Make each stage; returns (nn.Sequential of HRModules, in_channels)."""
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ for i in range(num_modules):
+ # multiscale_output=False only takes effect on the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*hr_modules), in_channels
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights, loaded
+ with ``strict=False``. Defaults to None (initialize in place).
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:  # neither a checkpoint path nor None
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function; returns the list produced by stage4."""
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))  # non-None transitions read y_list[-1]
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))  # non-None transitions read y_list[-1]
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return y_list
+
+ def train(self, mode=True):
+ """Switch the model to training mode while keeping the normalization
+ layers frozen (in eval mode) when ``self.norm_eval`` is set."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ # trick: eval() is applied to _BatchNorm modules only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v2.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5b6cd6d04c9da04669550d7f1fd24381460bf3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v2.py
@@ -0,0 +1,180 @@
+import logging
+
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, make_divisible
+
+
+@BACKBONES.register_module()
+class MobileNetV2(nn.Module):
+ """MobileNetV2 backbone.
+
+ Args:
+ widen_factor (float): Width multiplier, multiply number of
+ channels in each layer by this amount. Default: 1.0.
+ strides (Sequence[int], optional): Strides of the first block of each
+ layer. If not specified, default config in ``arch_setting`` will
+ be used.
+ dilations (Sequence[int]): Dilation of each layer.
+ out_indices (None or Sequence[int]): Output from which stages.
+ Default: (7, ).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU6').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed. Default: False.
+ """
+
+ # Parameters to build layers. 3 parameters are needed to construct a
+ # layer, from left to right: expand_ratio, channel, num_blocks.
+ arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
+ [6, 96, 3], [6, 160, 3], [6, 320, 1]]
+
+ def __init__(self,
+ widen_factor=1.,
+ strides=(1, 2, 2, 2, 1, 2, 1),
+ dilations=(1, 1, 1, 1, 1, 1, 1),
+ out_indices=(1, 2, 4, 6),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU6'),
+ norm_eval=False,
+ with_cp=False):
+ super(MobileNetV2, self).__init__()
+ self.widen_factor = widen_factor
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == len(self.arch_settings)
+ self.out_indices = out_indices
+ for index in out_indices:
+ if index not in range(0, 7):
+ raise ValueError('the item in out_indices must in '
+ f'range(0, 8). But received {index}')
+
+ if frozen_stages not in range(-1, 7):
+ raise ValueError('frozen_stages must be in range(-1, 7). '
+ f'But received {frozen_stages}')
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+
+ self.in_channels = make_divisible(32 * widen_factor, 8)
+
+ self.conv1 = ConvModule(
+ in_channels=3,
+ out_channels=self.in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.layers = []
+
+ for i, layer_cfg in enumerate(self.arch_settings):
+ expand_ratio, channel, num_blocks = layer_cfg
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ out_channels = make_divisible(channel * widen_factor, 8)
+ inverted_res_layer = self.make_layer(
+ out_channels=out_channels,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ expand_ratio=expand_ratio)
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, inverted_res_layer)
+ self.layers.append(layer_name)
+
+ def make_layer(self, out_channels, num_blocks, stride, dilation,
+ expand_ratio):
+ """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+ Args:
+ out_channels (int): out_channels of block.
+ num_blocks (int): Number of blocks.
+ stride (int): Stride of the first block.
+ dilation (int): Dilation of the first block.
+ expand_ratio (int): Expand the number of channels of the
+ hidden layer in InvertedResidual by this ratio.
+ """
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ InvertedResidual(
+ self.in_channels,
+ out_channels,
+ stride if i == 0 else 1,
+ expand_ratio=expand_ratio,
+ dilation=dilation if i == 0 else 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ with_cp=self.with_cp))
+ self.in_channels = out_channels
+
+ return nn.Sequential(*layers)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ x = self.conv1(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ layer = getattr(self, f'layer{i}')
+ layer.eval()
+ for param in layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(MobileNetV2, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v3.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c22bdd22356a600454f14c2ed12e7ef72c8ca1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v3.py
@@ -0,0 +1,255 @@
+import logging
+
+import annotator.mmpkg.mmcv as mmcv
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.mmpkg.mmcv.cnn.bricks import Conv2dAdaptivePadding
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidualV3 as InvertedResidual
+
+
+@BACKBONES.register_module()
+class MobileNetV3(nn.Module):
+ """MobileNetV3 backbone.
+
+ This backbone is the improved implementation of `Searching for MobileNetV3
+ <https://arxiv.org/abs/1905.02244>`_.
+
+ Args:
+ arch (str): Architecture of mobilenetv3, from {'small', 'large'}.
+ Default: 'small'.
+ conv_cfg (dict): Config dict for convolution layer.
+ Default: None, which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='BN').
+ out_indices (tuple[int]): Output from which layer.
+ Default: (0, 1, 12).
+ frozen_stages (int): Stages to be frozen (all param fixed).
+ Default: -1, which means not freezing any parameters.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save
+ some memory while slowing down the training speed.
+ Default: False.
+ """
+ # Parameters to build each block:
+ # [kernel size, mid channels, out channels, with_se, act type, stride]
+ arch_settings = {
+ 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
+ [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
+ [3, 88, 24, False, 'ReLU', 1],
+ [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 240, 40, True, 'HSwish', 1],
+ [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
+ [5, 144, 48, True, 'HSwish', 1],
+ [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
+ [5, 576, 96, True, 'HSwish', 1],
+ [5, 576, 96, True, 'HSwish', 1]],
+ 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
+ [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
+ [3, 72, 24, False, 'ReLU', 1],
+ [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
+ [5, 120, 40, True, 'ReLU', 1],
+ [5, 120, 40, True, 'ReLU', 1],
+ [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
+ [3, 200, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 184, 80, False, 'HSwish', 1],
+ [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
+ [3, 672, 112, True, 'HSwish', 1],
+ [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
+ [5, 960, 160, True, 'HSwish', 1],
+ [5, 960, 160, True, 'HSwish', 1]]
+ } # yapf: disable
+
+    def __init__(self,
+                 arch='small',
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 out_indices=(0, 1, 12),
+                 frozen_stages=-1,
+                 reduction_factor=1,
+                 norm_eval=False,
+                 with_cp=False):
+        super(MobileNetV3, self).__init__()
+        assert arch in self.arch_settings
+        assert isinstance(reduction_factor, int) and reduction_factor > 0
+        assert mmcv.is_tuple_of(out_indices, int)
+        # Valid indices: one per arch row plus the stem and the last conv.
+        num_layers = len(self.arch_settings[arch]) + 2
+        bad = [idx for idx in out_indices if idx not in range(0, num_layers)]
+        if bad:
+            raise ValueError('the item in out_indices must in '
+                             f'range(0, {num_layers}). '
+                             f'But received {bad[0]}')
+        if frozen_stages not in range(-1, num_layers):
+            raise ValueError('frozen_stages must be in range(-1, '
+                             f'{num_layers}). '
+                             f'But received {frozen_stages}')
+        self.arch = arch
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.reduction_factor = reduction_factor
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.layers = self._make_layer()
+
+ def _make_layer(self):
+ layers = []
+ # Every layer is registered as attribute ``layer<N>``; only the names are kept.
+ # build the first layer (layer0): stride-2 3x3 conv, 3 -> 16 channels
+ in_channels = 16
+ layer = ConvModule(
+ in_channels=3,
+ out_channels=in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ conv_cfg=dict(type='Conv2dAdaptivePadding'),
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ self.add_module('layer0', layer)
+ layers.append('layer0')
+
+ layer_setting = self.arch_settings[self.arch]
+ for i, params in enumerate(layer_setting):
+ (kernel_size, mid_channels, out_channels, with_se, act,
+ stride) = params
+ # Only the last stage (rows 12+ of 'large', 8+ of 'small') is narrowed.
+ if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
+ i >= 8:
+ mid_channels = mid_channels // self.reduction_factor
+ out_channels = out_channels // self.reduction_factor
+
+ if with_se:
+ se_cfg = dict(
+ channels=mid_channels,
+ ratio=4,
+ act_cfg=(dict(type='ReLU'),
+ dict(type='HSigmoid', bias=3.0, divisor=6.0)))
+ else:
+ se_cfg = None
+
+ layer = InvertedResidual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ mid_channels=mid_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ se_cfg=se_cfg,
+ with_expand_conv=(in_channels != mid_channels),
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type=act),
+ with_cp=self.with_cp)
+ in_channels = out_channels
+ layer_name = 'layer{}'.format(i + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # build the last layer
+ # block5 layer12 os=32 for small model
+ # block6 layer16 os=32 for large model
+ layer = ConvModule(
+ in_channels=in_channels,
+ out_channels=576 if self.arch == 'small' else 960,
+ kernel_size=1,
+ stride=1,
+ dilation=4,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=dict(type='HSwish'))
+ layer_name = 'layer{}'.format(len(layer_setting) + 1)
+ self.add_module(layer_name, layer)
+ layers.append(layer_name)
+
+ # next, convert backbone MobileNetV3 to a semantic segmentation version
+ if self.arch == 'small':
+ self.layer4.depthwise_conv.conv.stride = (1, 1)  # built with stride 2
+ self.layer9.depthwise_conv.conv.stride = (1, 1)  # built with stride 2
+ for i in range(4, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+ # layers[4..8] get dilation 2, every later layer dilation 4.
+ if i < 9:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding: dilation * ((kernel_size - 1) // 2)
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+ else:
+ self.layer7.depthwise_conv.conv.stride = (1, 1)  # built with stride 2
+ self.layer13.depthwise_conv.conv.stride = (1, 1)  # built with stride 2
+ for i in range(7, len(layers)):
+ layer = getattr(self, layers[i])
+ if isinstance(layer, InvertedResidual):
+ modified_module = layer.depthwise_conv.conv
+ else:
+ modified_module = layer.conv
+ # layers[7..12] get dilation 2, every later layer dilation 4.
+ if i < 13:
+ modified_module.dilation = (2, 2)
+ pad = 2
+ else:
+ modified_module.dilation = (4, 4)
+ pad = 4
+
+ if not isinstance(modified_module, Conv2dAdaptivePadding):
+ # Adjust padding: dilation * ((kernel_size - 1) // 2)
+ pad *= (modified_module.kernel_size[0] - 1) // 2
+ modified_module.padding = (pad, pad)
+
+ return layers
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):  # path of a checkpoint to load
+            load_checkpoint(
+                self, pretrained, strict=False, logger=logging.getLogger())
+        elif pretrained is None:  # no checkpoint: initialise from scratch
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):  # any BN / GN
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        """Run all layers in order; return the ``out_indices`` outputs (list)."""
+        feats, feat = [], x
+        for idx, name in enumerate(self.layers):
+            feat = getattr(self, name)(feat)
+            if idx in self.out_indices:
+                feats.append(feat)
+        return feats
+
+    def _freeze_stages(self):
+        """Put ``layer0``..``layer{frozen_stages}`` in eval mode, no grads."""
+        for stage_name in map('layer{}'.format, range(self.frozen_stages + 1)):
+            getattr(self, stage_name).eval()
+            for weight in getattr(self, stage_name).parameters():
+                weight.requires_grad = False
+
+    def train(self, mode=True):
+        super(MobileNetV3, self).train(mode)
+        self._freeze_stages()  # frozen stages must stay in eval mode
+        norm_layers = self.modules() if mode and self.norm_eval else ()
+        for m in (n for n in norm_layers if isinstance(n, _BatchNorm)):
+            m.eval()  # norm_eval: keep BN running statistics fixed
+        return self  # nn.Module.train()/eval() contract: return the module
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnest.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..076ef62195bac2a9660261446b5756c3880dfdf2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnest.py
@@ -0,0 +1,314 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.mmpkg.mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+
+
+class RSoftmax(nn.Module):
+    """Softmax across the radix splits, used by ``SplitAttentionConv2d``.
+
+    Args:
+        radix (int): Number of splits; 1 switches to a plain sigmoid.
+        groups (int): Number of groups the input is viewed as.
+    """
+
+    def __init__(self, radix, groups):
+        super().__init__()
+        self.groups = groups
+        self.radix = radix
+
+    def forward(self, x):
+        """Softmax over radix (flattened to (N, -1)), or sigmoid if radix <= 1."""
+        if self.radix <= 1:
+            return torch.sigmoid(x)
+        num = x.size(0)
+        # (N, groups, radix, -1) -> (N, radix, groups, -1); dim 1 is radix.
+        grouped = x.view(num, self.groups, self.radix, -1).transpose(1, 2)
+        weights = F.softmax(grouped, dim=1)
+        return weights.reshape(num, -1)
+
+
+class SplitAttentionConv2d(nn.Module):
+ """Split-Attention Conv2d in ResNeSt.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ channels (int): Number of output channels (``out_channels`` in nn.Conv2d).
+ kernel_size (int | tuple[int]): Same as nn.Conv2d.
+ stride (int | tuple[int]): Same as nn.Conv2d.
+ padding (int | tuple[int]): Same as nn.Conv2d.
+ dilation (int | tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ dcn (dict): Config dict for DCN. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None):
+ super(SplitAttentionConv2d, self).__init__()
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.with_dcn = dcn is not None
+ self.dcn = dcn
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_dcn and not fallback_on_stride:
+ assert conv_cfg is None, 'conv_cfg must be None for DCN'
+ conv_cfg = dcn
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ """nn.Module: the normalization layer named "norm0" """
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+    def forward(self, x):
+        """Apply the split conv, then fuse its radix splits by attention.
+
+        The result is returned as a contiguous tensor.
+        """
+        x = self.norm0(self.conv(x))
+        x = self.relu(x)
+        batch = x.size(0)
+
+        if self.radix > 1:
+            # (N, radix * C, H, W) -> (N, radix, C, H, W)
+            splits = x.view(batch, self.radix, -1, *x.shape[2:])
+            gap = splits.sum(dim=1)
+        else:
+            gap = x
+        # Pooled descriptor -> bottleneck (fc1) -> per-split logits (fc2).
+        gap = F.adaptive_avg_pool2d(gap, 1)
+        gap = self.relu(self.norm1(self.fc1(gap)))
+        atten = self.rsoftmax(self.fc2(gap)).view(batch, -1, 1, 1)
+
+        # Weight each split and sum over radix (plain gating if radix == 1).
+        if self.radix > 1:
+            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+            out = torch.sum(attens * splits, dim=1)
+        else:
+            out = atten * x
+        return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+ inplanes (int): Input planes of this block.
+ planes (int): Middle planes of this block.
+ groups (int): Groups of conv2.
+ base_width (int): Width per group of conv2. 64x4d indicates
+ ``groups=64, base_width=4`` and 32x8d indicates
+ ``groups=32, base_width=8``.
+ radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Key word arguments for base class.
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ """Bottleneck block for ResNeSt."""
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.with_modulated_dcn = False
+ self.conv2 = SplitAttentionConv2d(
+ width,
+ width,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=self.dcn)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Args:
+ groups (int): Number of groups of Bottleneck. Default: 1
+ base_width (int): Base width of Bottleneck. Default: 4
+ radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Keyword arguments for ResNet.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3))
+ }
+
+ def __init__(self,
+ groups=1,
+ base_width=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(**kwargs)
+
+    def make_res_layer(self, **kwargs):
+        """Build one stage as a ``ResLayer`` with the ResNeSt block options."""
+        resnest_opts = dict(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            radix=self.radix,
+            reduction_factor=self.reduction_factor,
+            avg_down_stride=self.avg_down_stride)
+        return ResLayer(**resnest_opts, **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnet.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3304dc5238110adcf21fa4c0a4e230158894fea
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnet.py
@@ -0,0 +1,688 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.mmpkg.mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+ constant_init, kaiming_init)
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.mmpkg.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import ResLayer
+
+
+class BasicBlock(nn.Module):
+ """Basic block for ResNet."""
+
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(BasicBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+ "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+ assert all(p['position'] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv1'
+ ]
+ self.after_conv2_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv2'
+ ]
+ self.after_conv3_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv3'
+ ]
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ planes * self.expansion, self.after_conv3_plugins)
+
+    def make_block_plugins(self, in_channels, plugins):
+        """Build the given plugins and register them on this block.
+
+        Args:
+            in_channels (int): Channel count handed to every plugin.
+            plugins (list[dict]): Plugin cfgs; an optional ``postfix`` key
+                is removed from a copy of each cfg before building.
+
+        Returns:
+            list[str]: Registered attribute names, in build order.
+        """
+        assert isinstance(plugins, list)
+        registered = []
+        for cfg in (plugin.copy() for plugin in plugins):
+            # Pop ``postfix`` so it is passed separately, not inside cfg.
+            postfix = cfg.pop('postfix', '')
+            name, layer = build_plugin_layer(
+                cfg, in_channels=in_channels, postfix=postfix)
+            assert not hasattr(self, name), f'duplicate plugin {name}'
+            self.add_module(name, layer)
+            registered.append(name)
+        return registered
+
+    def forward_plugin(self, x, plugin_names):
+        """Run ``x`` through the named plugins in order and return the result."""
+        out = x
+        for name in plugin_names:
+            out = getattr(self, name)(out)  # chain: feed the previous output
+        return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ stem_channels (int): Number of stem channels. Default: 64.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+
+ - position (str, required): Position inside block to insert plugin,
+ options: 'after_conv1', 'after_conv2', 'after_conv3'.
+
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'
+ multi_grid (Sequence[int]|None): Multi grid dilation rates of last
+ stage. Default: None
+ contract_dilation (bool): Whether contract first dilation of each layer
+ Default: False
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from annotator.mmpkg.mmseg.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=64,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=False,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ multi_grid=None,
+ contract_dilation=False,
+ with_cp=False,
+ zero_init_residual=True):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.multi_grid = multi_grid
+ self.contract_dilation = contract_dilation
+ self.zero_init_residual = zero_init_residual
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ # multi grid is applied to last layer only
+ stage_multi_grid = multi_grid if i == len(
+ self.stage_blocks) - 1 else None
+ planes = base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ multi_grid=stage_multi_grid,
+ contract_dilation=contract_dilation)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i+1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = self.block.expansion * base_channels * 2**(
+ len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+ """make plugins for ResNet 'stage_idx'th stage .
+
+ Currently we support to insert 'context_block',
+ 'empirical_attention_block', 'nonlocal_block' into the backbone like
+ ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+ Bottleneck.
+
+ An example of plugins format could be :
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+
+ Suppose 'stage_idx=0', the structure of blocks in the stage would be:
+ conv1-> conv2->conv3->yyy->zzz1->zzz2
+ Suppose 'stage_idx=1', the structure of blocks in the stage would be:
+ conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
+
+ If stages is missing, the plugin would be applied to all stages.
+
+ Args:
+ plugins (list[dict]): List of plugins cfg to build. The postfix is
+ required if multiple same type plugins are inserted.
+ stage_idx (int): Index of stage to build
+
+ Returns:
+ list[dict]: Plugins for current stage
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop('stages', None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+    def _make_stem_layer(self, in_channels, stem_channels):
+        """Build the stem: three 3x3 convs (deep stem) or one 7x7 conv."""
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=2,  # only the first stem conv downsamples
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels,  # last conv widens to the full stem width
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                stem_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, stem_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)  # read back via ``norm1``
+            self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    def _freeze_stages(self):
+        """Freeze params and norm stats of the stem and first stages."""
+        if self.frozen_stages >= 0:  # any value >= 0 freezes the stem
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):  # layer1 .. layerN
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights. If None,
+                the modules are initialized in place. Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+
+            if self.dcn is not None:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck) and hasattr(
+                            m, 'conv2_offset'):
+                        constant_init(m.conv2_offset, 0)
+
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)  # third norm of the block
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)  # second norm of the block
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        """Run the stem and every stage; return the selected stage outputs."""
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+        x = self.maxpool(x)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:  # keep only the requested stages
+                outs.append(x)
+        return tuple(outs)
+
+    def train(self, mode=True):
+        """Switch train/eval mode; frozen stages (and, when ``norm_eval`` is
+        set, all BatchNorm layers) stay in eval mode. Returns ``self``."""
+        super(ResNet, self).train(mode)
+        self._freeze_stages()  # the call above also switched frozen stages
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):  # eval() only matters for BN
+                    m.eval()
+        return self  # same contract as nn.Module.train(), so eval() chains
+
+
+@BACKBONES.register_module()
+class ResNetV1c(ResNet):
+    """ResNetV1c variant described in [1]_.
+
+    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
+    in the input stem with three 3x3 convs.
+
+    References:
+        .. [1] https://arxiv.org/pdf/1812.01187.pdf
+    """
+
+    def __init__(self, **kwargs):  # kwargs must not repeat the two flags below
+        super(ResNetV1c, self).__init__(
+            deep_stem=True, avg_down=False, **kwargs)
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+    """ResNetV1d variant (see https://arxiv.org/pdf/1812.01187.pdf).
+
+    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
+    the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+    avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+    """
+
+    def __init__(self, **kwargs):  # kwargs must not repeat the two flags below
+        super(ResNetV1d, self).__init__(
+            deep_stem=True, avg_down=True, **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnext.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..be0194da1714e8431309a9dd8a42afebdbc1baf5
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/resnext.py
@@ -0,0 +1,145 @@
+import math
+
+from annotator.mmpkg.mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+    """Bottleneck block for ResNeXt.
+
+    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+    "caffe", the stride-two layer is the first 1x1 conv layer.
+    """
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=1,
+                 base_width=4,
+                 base_channels=64,
+                 **kwargs):
+        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+        if groups == 1:
+            width = self.planes
+        else:
+            width = math.floor(self.planes *
+                               (base_width / base_channels)) * groups
+
+        self.norm1_name, norm1 = build_norm_layer(
+            self.norm_cfg, width, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(
+            self.norm_cfg, width, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(
+            self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            self.inplanes,
+            width,  # conv1/conv2 use the grouped width computed above
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        fallback_on_stride = False
+        self.with_modulated_dcn = False
+        if self.with_dcn:
+            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+        if not self.with_dcn or fallback_on_stride:
+            self.conv2 = build_conv_layer(
+                self.conv_cfg,
+                width,
+                width,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation,
+                groups=groups,
+                bias=False)
+        else:
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+            self.conv2 = build_conv_layer(
+                self.dcn,  # the DCN config takes the place of conv_cfg here
+                width,
+                width,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation,
+                groups=groups,
+                bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            self.planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+    """ResNeXt backbone.
+
+    Args:
+        depth (int): Depth of resnet, from {50, 101, 152}.
+        in_channels (int): Number of input image channels. Normally 3.
+        num_stages (int): Resnet stages, normally 4.
+        groups (int): Group of resnext.
+        base_width (int): Base width of resnext.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from annotator.mmpkg.mmseg.models import ResNeXt
+        >>> import torch
+        >>> self = ResNeXt(depth=50)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)
+    """
+
+    arch_settings = {
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self, groups=1, base_width=4, **kwargs):
+        self.groups = groups  # set first: presumably read during super init
+        self.base_width = base_width
+        super(ResNeXt, self).__init__(**kwargs)
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``"""
+        return ResLayer(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            **kwargs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/unet.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d19902ba273af02f8c9ce60f6632634633c1101
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/unet.py
@@ -0,0 +1,429 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.mmpkg.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+ build_norm_layer, constant_init, kaiming_init)
+from annotator.mmpkg.mmcv.runner import load_checkpoint
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.mmpkg.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import UpConvBlock
+
+
+class BasicConvBlock(nn.Module):
+    """Basic convolutional block for UNet.
+
+    This module consists of several plain convolutional layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Whether use stride convolution to downsample
+            the input feature map. If stride=2, it only uses stride convolution
+            in the first convolutional layer to downsample the input feature
+            map. Options are 1 or 2. Default: 1.
+        dilation (int): Whether use dilated convolution to expand the
+            receptive field. Set dilation rate of each convolutional layer and
+            the dilation rate of the first convolutional layer is always 1.
+            Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Not implemented yet, so it must be None. Default: None.
+        plugins (dict): Not implemented yet, must be None. Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 dcn=None,
+                 plugins=None):
+        super(BasicConvBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.with_cp = with_cp
+        convs = []
+        for i in range(num_convs):
+            convs.append(
+                ConvModule(
+                    in_channels=in_channels if i == 0 else out_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=stride if i == 0 else 1,  # first conv only
+                    dilation=1 if i == 0 else dilation,
+                    padding=1 if i == 0 else dilation,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+        self.convs = nn.Sequential(*convs)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:  # checkpoint only with grads
+            out = cp.checkpoint(self.convs, x)
+        else:
+            out = self.convs(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+    """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+    Upsamples the decoder feature map with ``nn.ConvTranspose2d``.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of the convolutional layer. Default: 4.
+        scale_factor (int): Stride of the transposed conv. Default: 2.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 kernel_size=4,
+                 scale_factor=2):
+        super(DeconvModule, self).__init__()
+
+        assert (kernel_size - scale_factor >= 0) and\
+               (kernel_size - scale_factor) % 2 == 0,\
+               f'kernel_size should be greater than or equal to scale_factor '\
+               f'and (kernel_size - scale_factor) should be even numbers, '\
+               f'while the kernel size is {kernel_size} and scale_factor is '\
+               f'{scale_factor}.'
+
+        stride = scale_factor
+        padding = (kernel_size - scale_factor) // 2  # exact: checked above
+        self.with_cp = with_cp
+        deconv = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding)
+
+        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+        activate = build_activation_layer(act_cfg)
+        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:  # checkpoint only with grads
+            out = cp.checkpoint(self.deconv_upsamping, x)
+        else:
+            out = self.deconv_upsamping(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+    """Interpolation upsample module in decoder for UNet.
+
+    This module uses interpolation to upsample feature map in the decoder
+    of UNet. It consists of one interpolation upsample layer and one
+    convolutional layer. It can be one interpolation upsample layer followed
+    by one convolutional layer (conv_first=False) or one convolutional layer
+    followed by one interpolation upsample layer (conv_first=True).
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        conv_first (bool): Whether convolutional layer or interpolation
+            upsample layer first. Default: False. It means interpolation
+            upsample layer followed by one convolutional layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+        upsample_cfg (dict): Interpolation config of the upsample layer.
+            Default: dict(
+                scale_factor=2, mode='bilinear', align_corners=False).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 conv_cfg=None,
+                 conv_first=False,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 upsample_cfg=dict(
+                     scale_factor=2, mode='bilinear', align_corners=False)):
+        super(InterpConv, self).__init__()
+
+        self.with_cp = with_cp
+        conv = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        upsample = nn.Upsample(**upsample_cfg)
+        if conv_first:
+            self.interp_upsample = nn.Sequential(conv, upsample)
+        else:
+            self.interp_upsample = nn.Sequential(upsample, conv)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:  # checkpoint only with grads
+            out = cp.checkpoint(self.interp_upsample, x)
+        else:
+            out = self.interp_upsample(x)
+        return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+    """UNet backbone.
+    U-Net: Convolutional Networks for Biomedical Image Segmentation.
+    https://arxiv.org/pdf/1505.04597.pdf
+
+    Args:
+        in_channels (int): Number of input image channels. Default: 3.
+        base_channels (int): Number of base channels of each stage.
+            The output channels of the first stage. Default: 64.
+        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+        strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
+            len(strides) is equal to num_stages. Normally the stride of the
+            first stage in encoder is 1. If strides[i]=2, it uses stride
+            convolution to downsample in the correspondence encoder stage.
+            Default: (1, 1, 1, 1, 1).
+        enc_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the correspondence encoder stage.
+            Default: (2, 2, 2, 2, 2).
+        dec_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the correspondence decoder stage.
+            Default: (2, 2, 2, 2).
+        downsamples (Sequence[int]): Whether use MaxPool to downsample the
+            feature map after the first stage of encoder
+            (stages: [1, num_stages)). If the correspondence encoder stage use
+            stride convolution (strides[i]=2), it will never use MaxPool to
+            downsample, even downsamples[i-1]=True.
+            Default: (True, True, True, True).
+        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+            Default: (1, 1, 1, 1, 1).
+        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+            Default: (1, 1, 1, 1).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+        plugins (dict): plugins for convolutional layers. Default: None.
+
+    Notice:
+        The input image size should be divisible by the whole downsample rate
+        of the encoder. More detail of the whole downsample rate can be found
+        in UNet._check_input_divisible.
+
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 base_channels=64,
+                 num_stages=5,
+                 strides=(1, 1, 1, 1, 1),
+                 enc_num_convs=(2, 2, 2, 2, 2),
+                 dec_num_convs=(2, 2, 2, 2),
+                 downsamples=(True, True, True, True),
+                 enc_dilations=(1, 1, 1, 1, 1),
+                 dec_dilations=(1, 1, 1, 1),
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 norm_eval=False,
+                 dcn=None,
+                 plugins=None):
+        super(UNet, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+        assert len(strides) == num_stages, \
+            'The length of strides should be equal to num_stages, '\
+            f'while the strides is {strides}, the length of '\
+            f'strides is {len(strides)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_num_convs) == num_stages, \
+            'The length of enc_num_convs should be equal to num_stages, '\
+            f'while the enc_num_convs is {enc_num_convs}, the length of '\
+            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_num_convs) == (num_stages-1), \
+            'The length of dec_num_convs should be equal to (num_stages-1), '\
+            f'while the dec_num_convs is {dec_num_convs}, the length of '\
+            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(downsamples) == (num_stages-1), \
+            'The length of downsamples should be equal to (num_stages-1), '\
+            f'while the downsamples is {downsamples}, the length of '\
+            f'downsamples is {len(downsamples)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_dilations) == num_stages, \
+            'The length of enc_dilations should be equal to num_stages, '\
+            f'while the enc_dilations is {enc_dilations}, the length of '\
+            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_dilations) == (num_stages-1), \
+            'The length of dec_dilations should be equal to (num_stages-1), '\
+            f'while the dec_dilations is {dec_dilations}, the length of '\
+            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        self.num_stages = num_stages
+        self.strides = strides
+        self.downsamples = downsamples
+        self.norm_eval = norm_eval
+        self.base_channels = base_channels
+
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+
+        for i in range(num_stages):
+            enc_conv_block = []
+            if i != 0:  # stage 0 has no pooling and no matching decoder block
+                if strides[i] == 1 and downsamples[i - 1]:
+                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+                upsample = (strides[i] != 1 or downsamples[i - 1])
+                self.decoder.append(
+                    UpConvBlock(
+                        conv_block=BasicConvBlock,
+                        in_channels=base_channels * 2**i,
+                        skip_channels=base_channels * 2**(i - 1),
+                        out_channels=base_channels * 2**(i - 1),
+                        num_convs=dec_num_convs[i - 1],
+                        stride=1,
+                        dilation=dec_dilations[i - 1],
+                        with_cp=with_cp,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg,
+                        upsample_cfg=upsample_cfg if upsample else None,
+                        dcn=None,
+                        plugins=None))
+
+            enc_conv_block.append(
+                BasicConvBlock(
+                    in_channels=in_channels,
+                    out_channels=base_channels * 2**i,
+                    num_convs=enc_num_convs[i],
+                    stride=strides[i],
+                    dilation=enc_dilations[i],
+                    with_cp=with_cp,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    dcn=None,
+                    plugins=None))
+            self.encoder.append((nn.Sequential(*enc_conv_block)))
+            in_channels = base_channels * 2**i  # input width of next stage
+
+    def forward(self, x):  # returns a list: last encoder output, then decoder
+        self._check_input_divisible(x)
+        enc_outs = []
+        for enc in self.encoder:
+            x = enc(x)
+            enc_outs.append(x)
+        dec_outs = [x]
+        for i in reversed(range(len(self.decoder))):  # deepest block first
+            x = self.decoder[i](enc_outs[i], x)  # enc_outs[i] is the skip input
+            dec_outs.append(x)
+
+        return dec_outs
+
+    def train(self, mode=True):
+        """Switch train/eval mode; when ``norm_eval`` is set, BatchNorm layers
+        are kept in eval mode. Returns ``self`` like ``nn.Module.train``."""
+        super(UNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):  # eval() only matters for BN
+                    m.eval()
+        return self  # keeps ``model.eval()`` / ``model.train()`` chainable
+
+    def _check_input_divisible(self, x):
+        h, w = x.shape[-2:]
+        whole_downsample_rate = 1  # doubled once per downsampling stage
+        for i in range(1, self.num_stages):
+            if self.strides[i] == 2 or self.downsamples[i - 1]:
+                whole_downsample_rate *= 2
+        assert (h % whole_downsample_rate == 0) \
+            and (w % whole_downsample_rate == 0),\
+            f'The input image size {(h, w)} should be divisible by the whole '\
+            f'downsample rate {whole_downsample_rate}, when num_stages is '\
+            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+            f'is {self.downsamples}.'
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/vit.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab1a393741b21c8185f4204946b751b1913ef98c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/backbones/vit.py
@@ -0,0 +1,459 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/vision_transformer.py."""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.mmpkg.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
+ constant_init, kaiming_init, normal_init)
+from annotator.mmpkg.mmcv.runner import _load_checkpoint
+from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.mmpkg.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_
+
+
+class Mlp(nn.Module):
+    """MLP layer for Encoder block.
+
+    Args:
+        in_features(int): Input dimension for the first fully
+            connected layer.
+        hidden_features(int): Output dimension for the first fully
+            connected layer. Falls back to ``in_features`` if not set.
+        out_features(int): Output dimension for the second fully
+            connected layer. Falls back to ``in_features`` if not set.
+        act_cfg(dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        drop(float): Drop rate for the dropout layer. Dropout rate has
+            to be between 0 and 1. Default: 0.
+    """
+
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_cfg=dict(type='GELU'),
+                 drop=0.):
+        super(Mlp, self).__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = Linear(in_features, hidden_features)
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)  # one module, applied twice in forward
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    """Attention layer for Encoder block.
+
+    Args:
+        dim (int): Dimension for the input vector.
+        num_heads (int): Number of parallel attention heads. Default: 8.
+        qkv_bias (bool): Enable bias for qkv if True. Default: False.
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        attn_drop (float): Drop rate for attention output weights.
+            Default: 0.
+        proj_drop (float): Drop rate for output weights. Default: 0.
+    """
+
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.):
+        super(Attention, self).__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads  # forward() needs dim % num_heads == 0
+        self.scale = qk_scale or head_dim**-0.5  # 0/None use the default
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        b, n, c = x.shape  # x must be 3-D
+        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
+                                  c // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # each (b, num_heads, n, c // heads)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(b, n, c)  # merge heads again
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    """Implements encoder block with residual connection.
+
+    Args:
+        dim (int): The feature dimension.
+        num_heads (int): Number of parallel attention heads.
+        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): Enable bias for qkv if True. Default: False.
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float): Drop rate for mlp output weights. Default: 0.
+        attn_drop (float): Drop rate for attention output weights. Default: 0.
+        proj_drop (float): Drop rate for attn layer output weights.
+            Default: 0.
+        drop_path (float): Drop rate for paths of model.
+            Default: 0.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN', eps=1e-6).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 drop_path=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN', eps=1e-6),
+                 with_cp=False):
+        super(Block, self).__init__()
+        self.with_cp = with_cp
+        _, self.norm1 = build_norm_layer(norm_cfg, dim)
+        self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
+                              proj_drop)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        _, self.norm2 = build_norm_layer(norm_cfg, dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_cfg=act_cfg,
+            drop=drop)
+
+    def forward(self, x):
+
+        def _inner_forward(x):  # pre-norm residual: attention, then MLP
+            out = x + self.drop_path(self.attn(self.norm1(x)))
+            out = out + self.drop_path(self.mlp(self.norm2(out)))
+            return out
+
+        if self.with_cp and x.requires_grad:  # checkpoint only with grads
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding.
+
+    Args:
+        img_size (int | tuple): Input image size. An int means a square
+            image. Default: 224.
+        patch_size (int): Width and height for a patch.
+            Default: 16.
+        in_channels (int): Input channels for images. Default: 3.
+        embed_dim (int): The embedding dimension. Default: 768.
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_channels=3,
+                 embed_dim=768):
+        super(PatchEmbed, self).__init__()
+        if isinstance(img_size, int):
+            self.img_size = (img_size, img_size)
+        elif isinstance(img_size, tuple):
+            self.img_size = img_size
+        else:  # note: a list is rejected here as well
+            raise TypeError('img_size must be type of int or tuple')
+        h, w = self.img_size
+        self.patch_size = (patch_size, patch_size)
+        self.num_patches = (h // patch_size) * (w // patch_size)
+        self.proj = Conv2d(
+            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):  # conv output is flattened to (batch, patches, dim)
+        return self.proj(x).flatten(2).transpose(1, 2)
+
+
+@BACKBONES.register_module()
+class VisionTransformer(nn.Module):
+ """Vision transformer backbone.
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
+ Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+
+ Args:
+ img_size (tuple): input image size. Default: (224, 224).
+ patch_size (int, tuple): patch size. Default: 16.
+ in_channels (int): number of input channels. Default: 3.
+ embed_dim (int): embedding dimension. Default: 768.
+ depth (int): depth of transformer. Default: 12.
+ num_heads (int): number of attention heads. Default: 12.
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ out_indices (list | tuple | int): Output from which stages.
+            Default: 11.
+ qkv_bias (bool): enable bias for qkv if True. Default: True.
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): dropout rate. Default: 0.
+ attn_drop_rate (float): attention dropout rate. Default: 0.
+ drop_path_rate (float): Rate of DropPath. Default: 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', eps=1e-6, requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='GELU').
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Default: False.
+ interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resize. Default: bicubic.
+ with_cls_token (bool): If concatenating class token into image tokens
+ as transformer input. Default: True.
+ with_cp (bool): Use checkpoint or not. Using checkpoint
+ will save some memory while slowing down the training speed.
+ Default: False.
+ """
+
+    def __init__(self,
+                 img_size=(224, 224),
+                 patch_size=16,
+                 in_channels=3,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4,
+                 out_indices=11,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+                 act_cfg=dict(type='GELU'),
+                 norm_eval=False,
+                 final_norm=False,
+                 with_cls_token=True,
+                 interpolate_mode='bicubic',
+                 with_cp=False):
+        super(VisionTransformer, self).__init__()
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.features = self.embed_dim = embed_dim
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=embed_dim)
+
+        self.with_cls_token = with_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        self.pos_embed = nn.Parameter(  # one position more than the patches
+            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        if isinstance(out_indices, int):  # a single index becomes a list
+            self.out_indices = [out_indices]
+        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+            self.out_indices = out_indices
+        else:
+            raise TypeError('out_indices must be type of int, list or tuple')
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,  # MLP dropout; was wrongly fed from dpr
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],  # stochastic depth decay rule
+                act_cfg=act_cfg,
+                norm_cfg=norm_cfg,
+                with_cp=with_cp) for i in range(depth)
+        ])
+
+        self.interpolate_mode = interpolate_mode
+        self.final_norm = final_norm
+        if final_norm:  # ``self.norm`` only exists when requested
+            _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ checkpoint = _load_checkpoint(pretrained, logger=logger)
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ if 'pos_embed' in state_dict.keys():
+ if self.pos_embed.shape != state_dict['pos_embed'].shape:
+ logger.info(msg=f'Resize the pos_embed shape from \
+{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+ h, w = self.img_size
+ pos_size = int(
+ math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+ state_dict['pos_embed'] = self.resize_pos_embed(
+ state_dict['pos_embed'], (h, w), (pos_size, pos_size),
+ self.patch_size, self.interpolate_mode)
+
+ self.load_state_dict(state_dict, False)
+
+ elif pretrained is None:
+ # We only implement the 'jax_impl' initialization implemented at
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ for n, m in self.named_modules():
+ if isinstance(m, Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ if 'mlp' in n:
+ normal_init(m.bias, std=1e-6)
+ else:
+ constant_init(m.bias, 0)
+ elif isinstance(m, Conv2d):
+ kaiming_init(m.weight, mode='fan_in')
+ if m.bias is not None:
+ constant_init(m.bias, 0)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ constant_init(m.bias, 0)
+ constant_init(m.weight, 1.0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def _pos_embeding(self, img, patched_img, pos_embed):
+ """Positional embedding method.
+
+ Resize the pos_embed, if the input image size doesn't match
+ the training size.
+ Args:
+ img (torch.Tensor): The inference image tensor, the shape
+ must be [B, C, H, W].
+ patched_img (torch.Tensor): The patched image, it should be
+ shape of [B, L1, C].
+ pos_embed (torch.Tensor): The pos_embed weights, it should be
+ shape of [B, L2, c].
+ Return:
+ torch.Tensor: The pos encoded image feature.
+ """
+ assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
+ 'the shapes of patched_img and pos_embed must be [B, L, C]'
+ x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
+ if x_len != pos_len:
+ if pos_len == (self.img_size[0] // self.patch_size) * (
+ self.img_size[1] // self.patch_size) + 1:
+ pos_h = self.img_size[0] // self.patch_size
+ pos_w = self.img_size[1] // self.patch_size
+ else:
+ raise ValueError(
+ 'Unexpected shape of pos_embed, got {}.'.format(
+ pos_embed.shape))
+ pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
+ (pos_h, pos_w), self.patch_size,
+ self.interpolate_mode)
+ return self.pos_drop(patched_img + pos_embed)
+
+ @staticmethod
+ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+ """Resize pos_embed weights.
+
+ Resize pos_embed with the interpolation method selected by ``mode``.
+ Args:
+ pos_embed (torch.Tensor): pos_embed weights.
+ input_shpae (tuple): Tuple for (input_h, input_w).
+ pos_shape (tuple): Tuple for (pos_h, pos_w).
+ patch_size (int): Patch size.
+ Return:
+ torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+ """
+ assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+ input_h, input_w = input_shpae
+ pos_h, pos_w = pos_shape
+ cls_token_weight = pos_embed[:, 0]
+ pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+ pos_embed_weight = pos_embed_weight.reshape(
+ 1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+ pos_embed_weight = F.interpolate(
+ pos_embed_weight,
+ size=[input_h // patch_size, input_w // patch_size],
+ align_corners=False,
+ mode=mode)
+ cls_token_weight = cls_token_weight.unsqueeze(1)
+ pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+ pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+ return pos_embed
+
+ def forward(self, inputs):
+ B = inputs.shape[0]
+
+ x = self.patch_embed(inputs)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = self._pos_embeding(inputs, x, self.pos_embed)
+
+ if not self.with_cls_token:
+ # Remove class token for transformer input
+ x = x[:, 1:]
+
+ outs = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i == len(self.blocks) - 1:
+ if self.final_norm:
+ x = self.norm(x)
+ if i in self.out_indices:
+ if self.with_cls_token:
+ # Remove class token and reshape token for decoder head
+ out = x[:, 1:]
+ else:
+ out = x
+ B, _, C = out.shape
+ out = out.reshape(B, inputs.shape[2] // self.patch_size,
+ inputs.shape[3] // self.patch_size,
+ C).permute(0, 3, 1, 2)
+ outs.append(out)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(VisionTransformer, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.LayerNorm):
+ m.eval()
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/builder.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd29ff66d523b854c739b580137db6f4155fc550
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/builder.py
@@ -0,0 +1,46 @@
+import warnings
+
+from annotator.mmpkg.mmcv.cnn import MODELS as MMCV_MODELS
+from annotator.mmpkg.mmcv.utils import Registry
+
+MODELS = Registry('models', parent=MMCV_MODELS)
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+SEGMENTORS = MODELS
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return NECKS.build(cfg)
+
+
+def build_head(cfg):
+ """Build head."""
+ return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return LOSSES.build(cfg)
+
+
+def build_segmentor(cfg, train_cfg=None, test_cfg=None):
+ """Build segmentor."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+ 'train_cfg and test_cfg is deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return SEGMENTORS.build(
+ cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac66d3cfe0ea04af45c0f3594bf135841c3812e3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/__init__.py
@@ -0,0 +1,28 @@
+from .ann_head import ANNHead
+from .apc_head import APCHead
+from .aspp_head import ASPPHead
+from .cc_head import CCHead
+from .da_head import DAHead
+from .dm_head import DMHead
+from .dnl_head import DNLHead
+from .ema_head import EMAHead
+from .enc_head import EncHead
+from .fcn_head import FCNHead
+from .fpn_head import FPNHead
+from .gc_head import GCHead
+from .lraspp_head import LRASPPHead
+from .nl_head import NLHead
+from .ocr_head import OCRHead
+# from .point_head import PointHead
+from .psa_head import PSAHead
+from .psp_head import PSPHead
+from .sep_aspp_head import DepthwiseSeparableASPPHead
+from .sep_fcn_head import DepthwiseSeparableFCNHead
+from .uper_head import UPerHead
+
+__all__ = [
+ 'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
+ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
+ 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
+ 'APCHead', 'DMHead', 'LRASPPHead'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ann_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ann_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..958c88e0ca4b9acdaf146b836462b9a101b2cdad
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ann_head.py
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PPMConcat(nn.ModuleList):
+ """Pyramid Pooling Module that only concat the features of each layer.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ """
+
+ def __init__(self, pool_scales=(1, 3, 6, 8)):
+ super(PPMConcat, self).__init__(
+ [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
+
+ def forward(self, feats):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(feats)
+ ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
+ concat_outs = torch.cat(ppm_outs, dim=2)
+ return concat_outs
+
+
+class SelfAttentionBlock(_SelfAttentionBlock):
+ """SelfAttentionBlock variant used by ANN.
+
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ share_key_query (bool): Whether share projection weight between key
+ and query projection.
+ query_scale (int): The scale of query feature map.
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, share_key_query, query_scale, key_pool_scales,
+ conv_cfg, norm_cfg, act_cfg):
+ key_psp = PPMConcat(key_pool_scales)
+ if query_scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=query_scale)
+ else:
+ query_downsample = None
+ super(SelfAttentionBlock, self).__init__(
+ key_in_channels=low_in_channels,
+ query_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=share_key_query,
+ query_downsample=query_downsample,
+ key_downsample=key_psp,
+ key_query_num_convs=1,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+
+class AFNB(nn.Module):
+ """Asymmetric Fusion Non-local Block(AFNB)
+
+ Args:
+ low_in_channels (int): Input channels of lower level feature,
+ which is the key feature for self-attention.
+ high_in_channels (int): Input channels of higher level feature,
+ which is the query feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels. Key and query projections
+ are not shared in this block (share_key_query is fixed to False).
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, low_in_channels, high_in_channels, channels,
+ out_channels, query_scales, key_pool_scales, conv_cfg,
+ norm_cfg, act_cfg):
+ super(AFNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=False,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ out_channels + high_in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ def forward(self, low_feats, high_feats):
+ """Forward function."""
+ priors = [stage(high_feats, low_feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, high_feats], 1))
+ return output
+
+
+class APNB(nn.Module):
+ """Asymmetric Pyramid Non-local Block (APNB)
+
+ Args:
+ in_channels (int): Input channels of key/query feature,
+ which is the key feature for self-attention.
+ channels (int): Output channels of key/query transform.
+ out_channels (int): Output channels.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module of key feature.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict|None): Config of activation layers.
+ """
+
+ def __init__(self, in_channels, channels, out_channels, query_scales,
+ key_pool_scales, conv_cfg, norm_cfg, act_cfg):
+ super(APNB, self).__init__()
+ self.stages = nn.ModuleList()
+ for query_scale in query_scales:
+ self.stages.append(
+ SelfAttentionBlock(
+ low_in_channels=in_channels,
+ high_in_channels=in_channels,
+ channels=channels,
+ out_channels=out_channels,
+ share_key_query=True,
+ query_scale=query_scale,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ self.bottleneck = ConvModule(
+ 2 * in_channels,
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+
+ def forward(self, feats):
+ """Forward function."""
+ priors = [stage(feats, feats) for stage in self.stages]
+ context = torch.stack(priors, dim=0).sum(dim=0)
+ output = self.bottleneck(torch.cat([context, feats], 1))
+ return output
+
+
+@HEADS.register_module()
+class ANNHead(BaseDecodeHead):
+ """Asymmetric Non-local Neural Networks for Semantic Segmentation.
+
+ This head is the implementation of `ANNNet
+ <https://arxiv.org/abs/1908.07678>`_.
+
+ Args:
+ project_channels (int): Projection channels for Nonlocal.
+ query_scales (tuple[int]): The scales of query feature map.
+ Default: (1,)
+ key_pool_scales (tuple[int]): The pooling scales of key feature map.
+ Default: (1, 3, 6, 8).
+ """
+
+ def __init__(self,
+ project_channels,
+ query_scales=(1, ),
+ key_pool_scales=(1, 3, 6, 8),
+ **kwargs):
+ super(ANNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(self.in_channels) == 2
+ low_in_channels, high_in_channels = self.in_channels
+ self.project_channels = project_channels
+ self.fusion = AFNB(
+ low_in_channels=low_in_channels,
+ high_in_channels=high_in_channels,
+ out_channels=high_in_channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ high_in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.context = APNB(
+ in_channels=self.channels,
+ out_channels=self.channels,
+ channels=project_channels,
+ query_scales=query_scales,
+ key_pool_scales=key_pool_scales,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ low_feats, high_feats = self._transform_inputs(inputs)
+ output = self.fusion(low_feats, high_feats)
+ output = self.dropout(output)
+ output = self.bottleneck(output)
+ output = self.context(output)
+ output = self.cls_seg(output)
+
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/apc_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/apc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f363dba391c3eb6fb5a4d61c145fd4976a5717d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/apc_head.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ACM(nn.Module):
+ """Adaptive Context Module used in APCNet.
+
+ Args:
+ pool_scale (int): Pooling scale used in Adaptive Context
+ Module to extract region features.
+ fusion (bool): Add one conv to fuse residual feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(ACM, self).__init__()
+ self.pool_scale = pool_scale
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.pooled_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.global_info = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
+
+ self.residual_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, x):
+ """Forward function."""
+ pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
+ # [batch_size, channels, h, w]
+ x = self.input_redu_conv(x)
+ # [batch_size, channels, pool_scale, pool_scale]
+ pooled_x = self.pooled_redu_conv(pooled_x)
+ batch_size = x.size(0)
+ # [batch_size, pool_scale * pool_scale, channels]
+ pooled_x = pooled_x.view(batch_size, self.channels,
+ -1).permute(0, 2, 1).contiguous()
+ # [batch_size, h * w, pool_scale * pool_scale]
+ affinity_matrix = self.gla(x + resize(
+ self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
+ ).permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.pool_scale**2)
+ affinity_matrix = F.sigmoid(affinity_matrix)
+ # [batch_size, h * w, channels]
+ z_out = torch.matmul(affinity_matrix, pooled_x)
+ # [batch_size, channels, h * w]
+ z_out = z_out.permute(0, 2, 1).contiguous()
+ # [batch_size, channels, h, w]
+ z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
+ z_out = self.residual_conv(z_out)
+ z_out = F.relu(z_out + x)
+ if self.fusion:
+ z_out = self.fusion_conv(z_out)
+
+ return z_out
+
+
+@HEADS.register_module()
+class APCHead(BaseDecodeHead):
+ """Adaptive Pyramid Context Network for Semantic Segmentation.
+
+ This head is the implementation of
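+ `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_CVPR_2019_paper.pdf>`_.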
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Adaptive Context
+ Module. Default: (1, 2, 3, 6).
+ fusion (bool): Add one conv to fuse residual feature.
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
+ super(APCHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.fusion = fusion
+ acm_modules = []
+ for pool_scale in self.pool_scales:
+ acm_modules.append(
+ ACM(pool_scale,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.acm_modules = nn.ModuleList(acm_modules)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ acm_outs = [x]
+ for acm_module in self.acm_modules:
+ acm_outs.append(acm_module(x))
+ acm_outs = torch.cat(acm_outs, dim=1)
+ output = self.bottleneck(acm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/aspp_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c0aadb2b097a604d96ba1c99c05663b7884b6e0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/aspp_head.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ASPPModule(nn.ModuleList):
+ """Atrous Spatial Pyramid Pooling (ASPP) Module.
+
+ Args:
+ dilations (tuple[int]): Dilation rate of each layer.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg):
+ super(ASPPModule, self).__init__()
+ self.dilations = dilations
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for dilation in dilations:
+ self.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1 if dilation == 1 else 3,
+ dilation=dilation,
+ padding=0 if dilation == 1 else dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+
+ def forward(self, x):
+ """Forward function."""
+ aspp_outs = []
+ for aspp_module in self:
+ aspp_outs.append(aspp_module(x))
+
+ return aspp_outs
+
+
+@HEADS.register_module()
+class ASPPHead(BaseDecodeHead):
+ """Rethinking Atrous Convolution for Semantic Image Segmentation.
+
+ This head is the implementation of `DeepLabV3
+ <https://arxiv.org/abs/1706.05587>`_.
+
+ Args:
+ dilations (tuple[int]): Dilation rates for ASPP module.
+ Default: (1, 6, 12, 18).
+ """
+
+ def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
+ super(ASPPHead, self).__init__(**kwargs)
+ assert isinstance(dilations, (list, tuple))
+ self.dilations = dilations
+ self.image_pool = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.aspp_modules = ASPPModule(
+ dilations,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ (len(dilations) + 1) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ aspp_outs = [
+ resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ]
+ aspp_outs.extend(self.aspp_modules(x))
+ aspp_outs = torch.cat(aspp_outs, dim=1)
+ output = self.bottleneck(aspp_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cascade_decode_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cascade_decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02122ca0e68743b1bf7a893afae96042f23838c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cascade_decode_head.py
@@ -0,0 +1,57 @@
+from abc import ABCMeta, abstractmethod
+
+from .decode_head import BaseDecodeHead
+
+
+class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
+ """Base class for cascade decode head used in
+ :class:`CascadeEncoderDecoder`."""
+
+ def __init__(self, *args, **kwargs):
+ super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)
+
+ @abstractmethod
+ def forward(self, inputs, prev_output):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ seg_logits = self.forward(inputs, prev_output)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+
+ return losses
+
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs, prev_output)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cc_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4f5b052445a4071952aa04274274da7d897c2c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cc_head.py
@@ -0,0 +1,45 @@
+import torch
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+try:
+ try:
+ from mmcv.ops import CrissCrossAttention
+ except ImportError:  # no usable upstream mmcv; fall back to the vendored copy
+ from annotator.mmpkg.mmcv.ops import CrissCrossAttention
+except ImportError:  # any import failure (not only a missing module) disables the op
+ CrissCrossAttention = None
+
+
+@HEADS.register_module()
+class CCHead(FCNHead):
+ """CCNet: Criss-Cross Attention for Semantic Segmentation.
+
+ This head is the implementation of `CCNet
+ <https://arxiv.org/abs/1811.11721>`_.
+
+ Args:
+ recurrence (int): Number of recurrence of Criss Cross Attention
+ module. Default: 2.
+ """
+
+ def __init__(self, recurrence=2, **kwargs):
+ if CrissCrossAttention is None:
+ raise RuntimeError('Please install mmcv-full for '
+ 'CrissCrossAttention ops')
+ super(CCHead, self).__init__(num_convs=2, **kwargs)
+ self.recurrence = recurrence
+ self.cca = CrissCrossAttention(self.channels)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ for _ in range(self.recurrence):
+ output = self.cca(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/da_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/da_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0b7616501c04cc0faf92accac9d3fdb6807f9e1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/da_head.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule, Scale
+from torch import nn
+
+from annotator.mmpkg.mmseg.core import add_prefix
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PAM(_SelfAttentionBlock):
+ """Position Attention Module (PAM)
+
+ Args:
+ in_channels (int): Input channels of key/query feature.
+ channels (int): Output channels of key/query transform.
+ """
+
+ def __init__(self, in_channels, channels):
+ super(PAM, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=None,
+ key_downsample=None,
+ key_query_num_convs=1,
+ key_query_norm=False,
+ value_out_num_convs=1,
+ value_out_norm=False,
+ matmul_norm=False,
+ with_out=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None)
+
+ self.gamma = Scale(0)
+
+ def forward(self, x):
+ """Forward function."""
+ out = super(PAM, self).forward(x, x)
+
+ out = self.gamma(out) + x
+ return out
+
+
+class CAM(nn.Module):
+ """Channel Attention Module (CAM)"""
+
+ def __init__(self):
+ super(CAM, self).__init__()
+ self.gamma = Scale(0)
+
+ def forward(self, x):
+ """Forward function."""
+ batch_size, channels, height, width = x.size()
+ proj_query = x.view(batch_size, channels, -1)
+ proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
+ energy = torch.bmm(proj_query, proj_key)
+ energy_new = torch.max(
+ energy, -1, keepdim=True)[0].expand_as(energy) - energy
+ attention = F.softmax(energy_new, dim=-1)
+ proj_value = x.view(batch_size, channels, -1)
+
+ out = torch.bmm(attention, proj_value)
+ out = out.view(batch_size, channels, height, width)
+
+ out = self.gamma(out) + x
+ return out
+
+
+@HEADS.register_module()
+class DAHead(BaseDecodeHead):
+ """Dual Attention Network for Scene Segmentation.
+
+ This head is the implementation of `DANet
+ <https://arxiv.org/abs/1809.02983>`_.
+
+ Args:
+ pam_channels (int): The channels of Position Attention Module(PAM).
+ """
+
+ def __init__(self, pam_channels, **kwargs):
+ super(DAHead, self).__init__(**kwargs)
+ self.pam_channels = pam_channels
+ self.pam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam = PAM(self.channels, pam_channels)
+ self.pam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.pam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+
+ self.cam_in_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam = CAM()
+ self.cam_out_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.cam_conv_seg = nn.Conv2d(
+ self.channels, self.num_classes, kernel_size=1)
+
+ def pam_cls_seg(self, feat):
+ """PAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.pam_conv_seg(feat)
+ return output
+
+ def cam_cls_seg(self, feat):
+ """CAM feature classification."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.cam_conv_seg(feat)
+ return output
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ pam_feat = self.pam_in_conv(x)
+ pam_feat = self.pam(pam_feat)
+ pam_feat = self.pam_out_conv(pam_feat)
+ pam_out = self.pam_cls_seg(pam_feat)
+
+ cam_feat = self.cam_in_conv(x)
+ cam_feat = self.cam(cam_feat)
+ cam_feat = self.cam_out_conv(cam_feat)
+ cam_out = self.cam_cls_seg(cam_feat)
+
+ feat_sum = pam_feat + cam_feat
+ pam_cam_out = self.cls_seg(feat_sum)
+
+ return pam_cam_out, pam_out, cam_out
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, only ``pam_cam`` is used."""
+ return self.forward(inputs)[0]
+
+ def losses(self, seg_logit, seg_label):
+ """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
+ pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
+ loss = dict()
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
+ 'pam_cam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
+ loss.update(
+ add_prefix(
+ super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
+ return loss
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/decode_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a74c89f2ef1274ffe947995722576ab2c78eaec1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/decode_head.py
@@ -0,0 +1,234 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import normal_init
+from annotator.mmpkg.mmcv.runner import auto_fp16, force_fp32
+
+from annotator.mmpkg.mmseg.core import build_pixel_sampler
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import build_loss
+from ..losses import accuracy
+
+
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+ """Abstract base class for segmentation decode heads.
+
+ Subclasses implement :meth:`forward`; this class supplies input
+ selection, the final ``conv_seg`` classifier, and loss computation.
+ Every argument after ``channels`` is keyword-only.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ num_classes (int): Number of classes.
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+ conv_cfg (dict|None): Config of conv layers. Default: None.
+ norm_cfg (dict|None): Config of norm layers. Default: None.
+ act_cfg (dict): Config of activation layers.
+ Default: dict(type='ReLU')
+ in_index (int|Sequence[int]): Input feature index. Default: -1
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resized to the
+ same size as the first one and then concatenated together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundled into
+ a list and passed into decode head.
+ None: Only one selected feature map is allowed.
+ Default: None.
+ loss_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss', use_sigmoid=False,
+ loss_weight=1.0).
+ ignore_index (int | None): The label index to be ignored. When using
+ masked BCE loss, ignore_index should be set to None. Default: 255
+ sampler (dict|None): The config of segmentation map sampler.
+ Default: None.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ *,
+ num_classes,
+ dropout_ratio=0.1,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ in_index=-1,
+ input_transform=None,
+ loss_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ ignore_index=255,
+ sampler=None,
+ align_corners=False):
+ super(BaseDecodeHead, self).__init__()
+ # Validates the arguments and sets self.in_channels, self.in_index and
+ # self.input_transform.
+ self._init_inputs(in_channels, in_index, input_transform)
+ self.channels = channels
+ self.num_classes = num_classes
+ self.dropout_ratio = dropout_ratio
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.in_index = in_index
+ self.loss_decode = build_loss(loss_decode)
+ self.ignore_index = ignore_index
+ self.align_corners = align_corners
+ if sampler is not None:
+ self.sampler = build_pixel_sampler(sampler, context=self)
+ else:
+ self.sampler = None
+
+ # 1x1 classifier mapping `channels` features to per-class logits.
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+ # A non-positive ratio disables dropout entirely (self.dropout is None).
+ if dropout_ratio > 0:
+ self.dropout = nn.Dropout2d(dropout_ratio)
+ else:
+ self.dropout = None
+ self.fp16_enabled = False
+
+ def extra_repr(self):
+ """Extra repr."""
+ s = f'input_transform={self.input_transform}, ' \
+ f'ignore_index={self.ignore_index}, ' \
+ f'align_corners={self.align_corners}'
+ return s
+
+ def _init_inputs(self, in_channels, in_index, input_transform):
+ """Check and initialize input transforms.
+
+ The in_channels, in_index and input_transform must match.
+ Specifically, when input_transform is None, only single feature map
+ will be selected. So in_channels and in_index must be of type int.
+ When input_transform is not None, in_channels and in_index must be
+ lists or tuples of the same length.
+
+ Args:
+ in_channels (int|Sequence[int]): Input channels.
+ in_index (int|Sequence[int]): Input feature index.
+ input_transform (str|None): Transformation type of input features.
+ Options: 'resize_concat', 'multiple_select', None.
+ 'resize_concat': Multiple feature maps will be resized to the
+ same size as the first one and then concatenated together.
+ Usually used in FCN head of HRNet.
+ 'multiple_select': Multiple feature maps will be bundled into
+ a list and passed into decode head.
+ None: Only one selected feature map is allowed.
+ """
+
+ if input_transform is not None:
+ assert input_transform in ['resize_concat', 'multiple_select']
+ self.input_transform = input_transform
+ self.in_index = in_index
+ if input_transform is not None:
+ assert isinstance(in_channels, (list, tuple))
+ assert isinstance(in_index, (list, tuple))
+ assert len(in_channels) == len(in_index)
+ if input_transform == 'resize_concat':
+ # Concatenation along channels: the head sees the summed width.
+ self.in_channels = sum(in_channels)
+ else:
+ self.in_channels = in_channels
+ else:
+ assert isinstance(in_channels, int)
+ assert isinstance(in_index, int)
+ self.in_channels = in_channels
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.conv_seg, mean=0, std=0.01)
+
+ def _transform_inputs(self, inputs):
+ """Transform inputs for decoder.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+
+ Returns:
+ Tensor | list[Tensor]: A single tensor for 'resize_concat' or
+ None, a list of tensors for 'multiple_select'.
+ """
+
+ if self.input_transform == 'resize_concat':
+ inputs = [inputs[i] for i in self.in_index]
+ # Resize every selected map to the spatial size of the first one.
+ upsampled_inputs = [
+ resize(
+ input=x,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for x in inputs
+ ]
+ inputs = torch.cat(upsampled_inputs, dim=1)
+ elif self.input_transform == 'multiple_select':
+ inputs = [inputs[i] for i in self.in_index]
+ else:
+ inputs = inputs[self.in_index]
+
+ return inputs
+
+ @auto_fp16()
+ @abstractmethod
+ def forward(self, inputs):
+ """Placeholder of forward function."""
+ pass
+
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # img_metas and train_cfg are not used by this base implementation.
+ seg_logits = self.forward(inputs)
+ losses = self.losses(seg_logits, gt_semantic_seg)
+ return losses
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+ return self.forward(inputs)
+
+ def cls_seg(self, feat):
+ """Classify each pixel."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.conv_seg(feat)
+ return output
+
+ @force_fp32(apply_to=('seg_logit', ))
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation loss.
+
+ Returns:
+ dict: ``loss_seg`` (decode loss) and ``acc_seg`` (accuracy).
+ """
+ loss = dict()
+ # Bring the logits to the label's spatial size before comparing them.
+ # NOTE(review): shape[2:] and squeeze(1) below suggest seg_label is
+ # (N, 1, H, W) - confirm against the data pipeline.
+ seg_logit = resize(
+ input=seg_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ if self.sampler is not None:
+ seg_weight = self.sampler.sample(seg_logit, seg_label)
+ else:
+ seg_weight = None
+ seg_label = seg_label.squeeze(1)
+ loss['loss_seg'] = self.loss_decode(
+ seg_logit,
+ seg_label,
+ weight=seg_weight,
+ ignore_index=self.ignore_index)
+ loss['acc_seg'] = accuracy(seg_logit, seg_label)
+ return loss
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dm_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dm_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6d0f6390d96c1eef4242cdc9aed91ec7714c6a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dm_head.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class DCM(nn.Module):
+ """Dynamic Convolutional Module used in DMNet.
+
+ A filter of size ``filter_size`` is generated from the input itself
+ (adaptive average pooling followed by a 1x1 conv) and then applied to
+ the channel-reduced input as a per-sample, per-channel convolution.
+
+ Args:
+ filter_size (int): The filter size of generated convolution kernel
+ used in Dynamic Convolutional Module.
+ fusion (bool): Add one conv to fuse DCM output feature.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict | None): Config of conv layers.
+ norm_cfg (dict | None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
+ norm_cfg, act_cfg):
+ super(DCM, self).__init__()
+ self.filter_size = filter_size
+ self.fusion = fusion
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ # Plain 1x1 conv (stride 1, no padding) applied to the pooled input to
+ # produce the dynamic filters.
+ self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
+ 0)
+
+ # 1x1 ConvModule reducing the input from in_channels to channels.
+ self.input_redu_conv = ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ # Norm/activation applied to the dynamic-conv output in forward().
+ if self.norm_cfg is not None:
+ self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
+ else:
+ self.norm = None
+ self.activate = build_activation_layer(self.act_cfg)
+
+ if self.fusion:
+ self.fusion_conv = ConvModule(
+ self.channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, x):
+ """Forward function.
+
+ Returns:
+ Tensor: Feature of shape [b, self.channels, h, w], where b, h, w
+ are taken from the channel-reduced input.
+ """
+ # Filters are generated from the input pooled to filter_size x filter_size.
+ generated_filter = self.filter_gen_conv(
+ F.adaptive_avg_pool2d(x, self.filter_size))
+ x = self.input_redu_conv(x)
+ b, c, h, w = x.shape
+ # [1, b * c, h, w], c = self.channels
+ x = x.view(1, b * c, h, w)
+ # [b * c, 1, filter_size, filter_size]
+ generated_filter = generated_filter.view(b * c, 1, self.filter_size,
+ self.filter_size)
+ # Total padding per axis is filter_size - 1, so the conv below keeps
+ # h x w. Even filter sizes put the extra pixel on the left/top side.
+ pad = (self.filter_size - 1) // 2
+ if (self.filter_size - 1) % 2 == 0:
+ p2d = (pad, pad, pad, pad)
+ else:
+ p2d = (pad + 1, pad, pad + 1, pad)
+ x = F.pad(input=x, pad=p2d, mode='constant', value=0)
+ # groups=b * c: every (sample, channel) pair is convolved with its own
+ # generated filter.
+ # [1, b * c, h, w]
+ output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
+ # [b, c, h, w]
+ output = output.view(b, c, h, w)
+ if self.norm is not None:
+ output = self.norm(output)
+ output = self.activate(output)
+
+ if self.fusion:
+ output = self.fusion_conv(output)
+
+ return output
+
+
+@HEADS.register_module()
+class DMHead(BaseDecodeHead):
+ """Dynamic Multi-scale Filters for Semantic Segmentation.
+
+ This head is the implementation of
+ `DMNet `_.
+
+ One DCM is built per entry of ``filter_sizes``; their outputs are
+ concatenated with the input and fused by a 3x3 bottleneck conv.
+
+ Args:
+ filter_sizes (tuple[int]): The size of generated convolutional filters
+ used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
+ fusion (bool): Add one conv to fuse DCM output feature.
+ Default: False.
+ **kwargs: Forwarded to ``BaseDecodeHead``.
+ """
+
+ def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
+ super(DMHead, self).__init__(**kwargs)
+ assert isinstance(filter_sizes, (list, tuple))
+ self.filter_sizes = filter_sizes
+ self.fusion = fusion
+ dcm_modules = []
+ for filter_size in self.filter_sizes:
+ dcm_modules.append(
+ DCM(filter_size,
+ self.fusion,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.dcm_modules = nn.ModuleList(dcm_modules)
+ # Input width matches forward(): the raw input (in_channels) plus one
+ # `channels`-wide output per DCM.
+ self.bottleneck = ConvModule(
+ self.in_channels + len(filter_sizes) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ # Keep the untouched input as the first element of the concatenation.
+ dcm_outs = [x]
+ for dcm_module in self.dcm_modules:
+ dcm_outs.append(dcm_module(x))
+ dcm_outs = torch.cat(dcm_outs, dim=1)
+ output = self.bottleneck(dcm_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dnl_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dnl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3bb1de1499ad043cc51b2269b4d970d07c16076
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dnl_head.py
@@ -0,0 +1,131 @@
+import torch
+from annotator.mmpkg.mmcv.cnn import NonLocal2d
+from torch import nn
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+class DisentangledNonLocal2d(NonLocal2d):
+ """Disentangled Non-Local Blocks.
+
+ Args:
+ temperature (float): Temperature to adjust attention. Default: 0.05
+ """
+
+ def __init__(self, *arg, temperature, **kwargs):
+ super().__init__(*arg, **kwargs)
+ self.temperature = temperature
+ self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+
+ def embedded_gaussian(self, theta_x, phi_x):
+ """Embedded gaussian with temperature."""
+
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = torch.matmul(theta_x, phi_x)
+ if self.use_scale:
+ # theta_x.shape[-1] is `self.inter_channels`
+ pairwise_weight /= theta_x.shape[-1]**0.5
+ pairwise_weight /= self.temperature
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
+ return pairwise_weight
+
+ def forward(self, x):
+ # x: [N, C, H, W]
+ n = x.size(0)
+
+ # g_x: [N, HxW, C]
+ g_x = self.g(x).view(n, self.inter_channels, -1)
+ g_x = g_x.permute(0, 2, 1)
+
+ # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+ if self.mode == 'gaussian':
+ theta_x = x.view(n, self.in_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ if self.sub_sample:
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
+ else:
+ phi_x = x.view(n, self.in_channels, -1)
+ elif self.mode == 'concatenation':
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+ else:
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
+ theta_x = theta_x.permute(0, 2, 1)
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+ # subtract mean
+ theta_x -= theta_x.mean(dim=-2, keepdim=True)
+ phi_x -= phi_x.mean(dim=-1, keepdim=True)
+
+ pairwise_func = getattr(self, self.mode)
+ # pairwise_weight: [N, HxW, HxW]
+ pairwise_weight = pairwise_func(theta_x, phi_x)
+
+ # y: [N, HxW, C]
+ y = torch.matmul(pairwise_weight, g_x)
+ # y: [N, C, H, W]
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+ *x.size()[2:])
+
+ # unary_mask: [N, 1, HxW]
+ unary_mask = self.conv_mask(x)
+ unary_mask = unary_mask.view(n, 1, -1)
+ unary_mask = unary_mask.softmax(dim=-1)
+ # unary_x: [N, 1, C]
+ unary_x = torch.matmul(unary_mask, g_x)
+ # unary_x: [N, C, 1, 1]
+ unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+ n, self.inter_channels, 1, 1)
+
+ output = x + self.conv_out(y + unary_x)
+
+ return output
+
+
+@HEADS.register_module()
+class DNLHead(FCNHead):
+ """Disentangled Non-Local Neural Networks.
+
+ This head is the implementation of `DNLNet
+ `_.
+
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ sqrt(1/inter_channels). Default: True.
+ mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+ 'dot_product'. Default: 'embedded_gaussian'.
+ temperature (float): Temperature to adjust attention. Default: 0.05
+ **kwargs: Forwarded to ``FCNHead``; ``num_convs`` is fixed to 2 here
+ and must not be passed.
+ """
+
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ temperature=0.05,
+ **kwargs):
+ # Exactly two convs are required: forward() indexes convs[0] and convs[1].
+ super(DNLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.temperature = temperature
+ self.dnl_block = DisentangledNonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode,
+ temperature=self.temperature)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ # The DNL block sits between the two FCN convs.
+ output = self.convs[0](x)
+ output = self.dnl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ema_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaebae7b25579cabcd3967da765568a282869a49
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,168 @@
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+def reduce_mean(tensor):
+ """Reduce mean when distributed training."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
+
+
+class EMAModule(nn.Module):
+ """Expectation Maximization Attention Module used in EMANet.
+
+ Args:
+ channels (int): Channels of the whole module.
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations. Must be >= 1.
+ momentum (float): Weight of the newly estimated bases in the moving
+ average that updates the ``bases`` buffer during training.
+ """
+
+ def __init__(self, channels, num_bases, num_stages, momentum):
+ super(EMAModule, self).__init__()
+ assert num_stages >= 1, 'num_stages must be at least 1!'
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.momentum = momentum
+
+ # Random initial bases, L2-normalized along the channel dimension.
+ bases = torch.zeros(1, channels, self.num_bases)
+ bases.normal_(0, math.sqrt(2. / self.num_bases))
+ # [1, channels, num_bases]
+ bases = F.normalize(bases, dim=1, p=2)
+ # Registered as a buffer, not a parameter; forward() overwrites it
+ # with a moving average while in training mode.
+ self.register_buffer('bases', bases)
+
+ def forward(self, feats):
+ """Forward function.
+
+ Args:
+ feats (Tensor): Features of shape
+ [batch_size, channels, height, width].
+
+ Returns:
+ Tensor: Reconstructed features with the same shape as ``feats``.
+ """
+ batch_size, channels, height, width = feats.size()
+ # [batch_size, channels, height*width]
+ feats = feats.view(batch_size, channels, height * width)
+ # [batch_size, channels, num_bases]
+ bases = self.bases.repeat(batch_size, 1, 1)
+
+ # The EM iterations themselves run without gradient tracking.
+ with torch.no_grad():
+ for i in range(self.num_stages):
+ # [batch_size, height*width, num_bases]
+ attention = torch.einsum('bcn,bck->bnk', feats, bases)
+ attention = F.softmax(attention, dim=2)
+ # l1 norm
+ attention_normed = F.normalize(attention, dim=1, p=1)
+ # [batch_size, channels, num_bases]
+ bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+
+ # Reconstruct from the final bases and the last (un-normed) attention.
+ feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
+ feats_recon = feats_recon.view(batch_size, channels, height, width)
+
+ if self.training:
+ # Average the bases over the batch (and over processes when
+ # distributed), then blend them into the stored buffer.
+ bases = bases.mean(dim=0, keepdim=True)
+ bases = reduce_mean(bases)
+ # l2 norm
+ bases = F.normalize(bases, dim=1, p=2)
+ self.bases = (1 -
+ self.momentum) * self.bases + self.momentum * bases
+
+ return feats_recon
+
+
+@HEADS.register_module()
+class EMAHead(BaseDecodeHead):
+ """Expectation Maximization Attention Networks for Semantic Segmentation.
+
+ This head is the implementation of `EMANet
+ `_.
+
+ Args:
+ ema_channels (int): EMA module channels
+ num_bases (int): Number of bases.
+ num_stages (int): Number of the EM iterations.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer. Default: True
+ momentum (float): Momentum to update the base. Default: 0.1.
+ **kwargs: Forwarded to ``BaseDecodeHead``.
+ """
+
+ def __init__(self,
+ ema_channels,
+ num_bases,
+ num_stages,
+ concat_input=True,
+ momentum=0.1,
+ **kwargs):
+ super(EMAHead, self).__init__(**kwargs)
+ self.ema_channels = ema_channels
+ self.num_bases = num_bases
+ self.num_stages = num_stages
+ self.concat_input = concat_input
+ self.momentum = momentum
+ self.ema_module = EMAModule(self.ema_channels, self.num_bases,
+ self.num_stages, self.momentum)
+
+ self.ema_in_conv = ConvModule(
+ self.in_channels,
+ self.ema_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ # project (0, inf) -> (-inf, inf)
+ self.ema_mid_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=None,
+ act_cfg=None)
+ # The projection conv is frozen: its parameters never receive updates.
+ for param in self.ema_mid_conv.parameters():
+ param.requires_grad = False
+
+ # No activation here; ReLU is applied after the residual add in
+ # forward().
+ self.ema_out_conv = ConvModule(
+ self.ema_channels,
+ self.ema_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.bottleneck = ConvModule(
+ self.ema_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ feats = self.ema_in_conv(x)
+ # Saved for the residual connection around the EMA module.
+ identity = feats
+ feats = self.ema_mid_conv(feats)
+ recon = self.ema_module(feats)
+ recon = F.relu(recon, inplace=True)
+ recon = self.ema_out_conv(recon)
+ output = F.relu(identity + recon, inplace=True)
+ output = self.bottleneck(output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/enc_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/enc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2a22a90b26f3264f63234694f0f290a7891ea2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/enc_head.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule, build_norm_layer
+
+from annotator.mmpkg.mmseg.ops import Encoding, resize
+from ..builder import HEADS, build_loss
+from .decode_head import BaseDecodeHead
+
+
+class EncModule(nn.Module):
+ """Encoding Module used in EncNet.
+
+ Args:
+ in_channels (int): Input channels.
+ num_codes (int): Number of code words.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ """
+
+ def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
+ super(EncModule, self).__init__()
+ self.encoding_project = ConvModule(
+ in_channels,
+ in_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ # TODO: resolve this hack
+ # change to 1d
+ # The norm after Encoding must be a 1d variant, so the 2d norm type
+ # name is rewritten on a copy of the config (caller's dict untouched).
+ if norm_cfg is not None:
+ encoding_norm_cfg = norm_cfg.copy()
+ if encoding_norm_cfg['type'] in ['BN', 'IN']:
+ encoding_norm_cfg['type'] += '1d'
+ else:
+ encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
+ '2d', '1d')
+ else:
+ # fallback to BN1d
+ encoding_norm_cfg = dict(type='BN1d')
+ self.encoding = nn.Sequential(
+ Encoding(channels=in_channels, num_codes=num_codes),
+ build_norm_layer(encoding_norm_cfg, num_codes)[1],
+ nn.ReLU(inplace=True))
+ # Sigmoid-gated linear layer producing one scale per channel.
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, in_channels), nn.Sigmoid())
+
+ def forward(self, x):
+ """Forward function.
+
+ Returns:
+ tuple: ``(encoding_feat, output)`` where ``output`` is
+ ``relu(x + x * gamma)`` with a per-channel gate ``gamma``.
+ """
+ encoding_projection = self.encoding_project(x)
+ # NOTE(review): presumably Encoding yields (N, num_codes, C), so the
+ # mean over dim 1 averages the code words - confirm against
+ # mmseg.ops.Encoding.
+ encoding_feat = self.encoding(encoding_projection).mean(dim=1)
+ batch_size, channels, _, _ = x.size()
+ gamma = self.fc(encoding_feat)
+ # Reshape the gate so it broadcasts over the spatial dimensions.
+ y = gamma.view(batch_size, channels, 1, 1)
+ output = F.relu_(x + x * y)
+ return encoding_feat, output
+
+
+@HEADS.register_module()
+class EncHead(BaseDecodeHead):
+ """Context Encoding for Semantic Segmentation.
+
+ This head is the implementation of `EncNet
+ `_.
+
+ Args:
+ num_codes (int): Number of code words. Default: 32.
+ use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
+ regularize the training. Default: True.
+ add_lateral (bool): Whether use lateral connection to fuse features.
+ Default: False.
+ loss_se_decode (dict): Config of decode loss.
+ Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
+ """
+
+ def __init__(self,
+ num_codes=32,
+ use_se_loss=True,
+ add_lateral=False,
+ loss_se_decode=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=0.2),
+ **kwargs):
+ super(EncHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ self.use_se_loss = use_se_loss
+ self.add_lateral = add_lateral
+ self.num_codes = num_codes
+ self.bottleneck = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if add_lateral:
+ self.lateral_convs = nn.ModuleList()
+ for in_channels in self.in_channels[:-1]: # skip the last one
+ self.lateral_convs.append(
+ ConvModule(
+ in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ self.fusion = ConvModule(
+ len(self.in_channels) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.enc_module = EncModule(
+ self.channels,
+ num_codes=num_codes,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ if self.use_se_loss:
+ self.loss_se_decode = build_loss(loss_se_decode)
+ self.se_layer = nn.Linear(self.channels, self.num_classes)
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+ feat = self.bottleneck(inputs[-1])
+ if self.add_lateral:
+ laterals = [
+ resize(
+ lateral_conv(inputs[i]),
+ size=feat.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ feat = self.fusion(torch.cat([feat, *laterals], 1))
+ encode_feat, output = self.enc_module(feat)
+ output = self.cls_seg(output)
+ if self.use_se_loss:
+ se_output = self.se_layer(encode_feat)
+ return output, se_output
+ else:
+ return output
+
+ def forward_test(self, inputs, img_metas, test_cfg):
+ """Forward function for testing, ignore se_loss."""
+ if self.use_se_loss:
+ return self.forward(inputs)[0]
+ else:
+ return self.forward(inputs)
+
+ @staticmethod
+ def _convert_to_onehot_labels(seg_label, num_classes):
+ """Convert segmentation label to onehot.
+
+ Args:
+ seg_label (Tensor): Segmentation label of shape (N, H, W).
+ num_classes (int): Number of classes.
+
+ Returns:
+ Tensor: Onehot labels of shape (N, num_classes).
+ """
+
+ batch_size = seg_label.size(0)
+ onehot_labels = seg_label.new_zeros((batch_size, num_classes))
+ for i in range(batch_size):
+ hist = seg_label[i].float().histc(
+ bins=num_classes, min=0, max=num_classes - 1)
+ onehot_labels[i] = hist > 0
+ return onehot_labels
+
+ def losses(self, seg_logit, seg_label):
+ """Compute segmentation and semantic encoding loss."""
+ seg_logit, se_seg_logit = seg_logit
+ loss = dict()
+ loss.update(super(EncHead, self).losses(seg_logit, seg_label))
+ se_loss = self.loss_se_decode(
+ se_seg_logit,
+ self._convert_to_onehot_labels(seg_label, self.num_classes))
+ loss['loss_se'] = se_loss
+ return loss
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fcn_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4583c57246e8e3b1d15d240b943d046afa5cba5
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fcn_head.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FCNHead(BaseDecodeHead):
+ """Fully Convolution Networks for Semantic Segmentation.
+
+ This head is implemented of `FCNNet `_.
+
+ Args:
+ num_convs (int): Number of convs in the head. Default: 2.
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
+ concat_input (bool): Whether concat the input and output of convs
+ before classification layer.
+ dilation (int): The dilation rate for convs in the head. Default: 1.
+ """
+
+ def __init__(self,
+ num_convs=2,
+ kernel_size=3,
+ concat_input=True,
+ dilation=1,
+ **kwargs):
+ assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
+ self.num_convs = num_convs
+ self.concat_input = concat_input
+ self.kernel_size = kernel_size
+ super(FCNHead, self).__init__(**kwargs)
+ if num_convs == 0:
+ assert self.in_channels == self.channels
+
+ conv_padding = (kernel_size // 2) * dilation
+ convs = []
+ convs.append(
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ for i in range(num_convs - 1):
+ convs.append(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=conv_padding,
+ dilation=dilation,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ if num_convs == 0:
+ self.convs = nn.Identity()
+ else:
+ self.convs = nn.Sequential(*convs)
+ if self.concat_input:
+ self.conv_cat = ConvModule(
+ self.in_channels + self.channels,
+ self.channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs(x)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fpn_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a9ba39eebc406bfa422dc98eeaa32a800008a83
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fpn_head.py
@@ -0,0 +1,68 @@
+import numpy as np
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FPNHead(BaseDecodeHead):
+ """Panoptic Feature Pyramid Networks.
+
+ This head is the implementation of `Semantic FPN
+ `_.
+
+ Args:
+ feature_strides (tuple[int]): The strides for input feature maps.
+ All strides are supposed to be powers of 2. The first one is of
+ the largest resolution (i.e. the smallest stride).
+ **kwargs: Forwarded to ``BaseDecodeHead``; ``input_transform`` is
+ fixed to ``'multiple_select'`` here and must not be passed.
+ """
+
+ def __init__(self, feature_strides, **kwargs):
+ super(FPNHead, self).__init__(
+ input_transform='multiple_select', **kwargs)
+ assert len(feature_strides) == len(self.in_channels)
+ assert min(feature_strides) == feature_strides[0]
+ self.feature_strides = feature_strides
+
+ self.scale_heads = nn.ModuleList()
+ for i in range(len(feature_strides)):
+ # One conv stage per factor of 2 between this level's stride and
+ # the first level's stride, with at least one stage.
+ head_length = max(
+ 1,
+ int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
+ scale_head = []
+ for k in range(head_length):
+ scale_head.append(
+ ConvModule(
+ self.in_channels[i] if k == 0 else self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg))
+ # Coarser levels are upsampled x2 after every conv stage.
+ if feature_strides[i] != feature_strides[0]:
+ scale_head.append(
+ nn.Upsample(
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=self.align_corners))
+ self.scale_heads.append(nn.Sequential(*scale_head))
+
+ def forward(self, inputs):
+ """Forward function: sum all scale-head outputs, then classify."""
+
+ x = self._transform_inputs(inputs)
+
+ output = self.scale_heads[0](x[0])
+ for i in range(1, len(self.feature_strides)):
+ # non inplace
+ output = output + resize(
+ self.scale_heads[i](x[i]),
+ size=output.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/gc_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/gc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6342811f67e4affac7886c8fc745a28abcc32c55
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/gc_head.py
@@ -0,0 +1,47 @@
+import torch
+from annotator.mmpkg.mmcv.cnn import ContextBlock
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class GCHead(FCNHead):
+ """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
+
+ This head is the implementation of `GCNet
+ `_.
+
+ Args:
+ ratio (float): Multiplier of channels ratio. Default: 1/4.
+ pooling_type (str): The pooling type of context aggregation.
+ Options are 'att', 'avg'. Default: 'avg'.
+ fusion_types (tuple[str]): The fusion type for feature fusion.
+ Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
+ """
+
+ def __init__(self,
+ ratio=1 / 4.,
+ pooling_type='att',
+ fusion_types=('channel_add', ),
+ **kwargs):
+ super(GCHead, self).__init__(num_convs=2, **kwargs)
+ self.ratio = ratio
+ self.pooling_type = pooling_type
+ self.fusion_types = fusion_types
+ self.gc_block = ContextBlock(
+ in_channels=self.channels,
+ ratio=self.ratio,
+ pooling_type=self.pooling_type,
+ fusion_types=self.fusion_types)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.gc_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/lraspp_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/lraspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b29d80e77d05cc0c12118e335e266a73bda99ed0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/lraspp_head.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv import is_tuple_of
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class LRASPPHead(BaseDecodeHead):
+ """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
+
+ This head is the improved implementation of `Searching for MobileNetV3
+ `_.
+
+ Args:
+ branch_channels (tuple[int]): The number of output channels in every
+ each branch. Default: (32, 64).
+ """
+
+ def __init__(self, branch_channels=(32, 64), **kwargs):
+ super(LRASPPHead, self).__init__(**kwargs)
+ if self.input_transform != 'multiple_select':
+ raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
+ f'must be \'multiple_select\'. But received '
+ f'\'{self.input_transform}\'')
+ assert is_tuple_of(branch_channels, int)
+ assert len(branch_channels) == len(self.in_channels) - 1
+ self.branch_channels = branch_channels
+
+ self.convs = nn.Sequential()
+ self.conv_ups = nn.Sequential()
+ for i in range(len(branch_channels)):
+ self.convs.add_module(
+ f'conv{i}',
+ nn.Conv2d(
+ self.in_channels[i], branch_channels[i], 1, bias=False))
+ self.conv_ups.add_module(
+ f'conv_up{i}',
+ ConvModule(
+ self.channels + branch_channels[i],
+ self.channels,
+ 1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=False))
+
+ self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
+
+ self.aspp_conv = ConvModule(
+ self.in_channels[-1],
+ self.channels,
+ 1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ bias=False)
+ self.image_pool = nn.Sequential(
+ nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
+ ConvModule(
+ self.in_channels[2],
+ self.channels,
+ 1,
+ act_cfg=dict(type='Sigmoid'),
+ bias=False))
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+
+ x = inputs[-1]
+
+ x = self.aspp_conv(x) * resize(
+ self.image_pool(x),
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ x = self.conv_up_input(x)
+
+ for i in range(len(self.branch_channels) - 1, -1, -1):
+ x = resize(
+ x,
+ size=inputs[i].size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ x = torch.cat([x, self.convs[i](inputs[i])], 1)
+ x = self.conv_ups[i](x)
+
+ return self.cls_seg(x)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/nl_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/nl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5990df1b8b0d57cfa772ec1b6b6be20a8f667ce7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/nl_head.py
@@ -0,0 +1,49 @@
+import torch
+from annotator.mmpkg.mmcv.cnn import NonLocal2d
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class NLHead(FCNHead):
+ """Non-local Neural Networks.
+
+ This head is the implementation of `NLNet
+ `_.
+
+ Args:
+ reduction (int): Reduction factor of projection transform. Default: 2.
+ use_scale (bool): Whether to scale pairwise_weight by
+ sqrt(1/inter_channels). Default: True.
+ mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+ 'dot_product'. Default: 'embedded_gaussian.'.
+ """
+
+ def __init__(self,
+ reduction=2,
+ use_scale=True,
+ mode='embedded_gaussian',
+ **kwargs):
+ super(NLHead, self).__init__(num_convs=2, **kwargs)
+ self.reduction = reduction
+ self.use_scale = use_scale
+ self.mode = mode
+ self.nl_block = NonLocal2d(
+ in_channels=self.channels,
+ reduction=self.reduction,
+ use_scale=self.use_scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ mode=self.mode)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ output = self.convs[0](x)
+ output = self.nl_block(output)
+ output = self.convs[1](output)
+ if self.concat_input:
+ output = self.conv_cat(torch.cat([x, output], dim=1))
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ocr_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ocr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46d10e5baff54e182af0426a1ecfea9ca190a9f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ocr_head.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+class SpatialGatherModule(nn.Module):
+ """Aggregate the context features according to the initial predicted
+ probability distribution.
+
+ Employ the soft-weighted method to aggregate the context.
+ """
+
+ def __init__(self, scale):
+ super(SpatialGatherModule, self).__init__()
+ self.scale = scale
+
+ def forward(self, feats, probs):
+ """Forward function."""
+ batch_size, num_classes, height, width = probs.size()
+ channels = feats.size(1)
+ probs = probs.view(batch_size, num_classes, -1)
+ feats = feats.view(batch_size, channels, -1)
+ # [batch_size, height*width, num_classes]
+ feats = feats.permute(0, 2, 1)
+ # [batch_size, channels, height*width]
+ probs = F.softmax(self.scale * probs, dim=2)
+ # [batch_size, channels, num_classes]
+ ocr_context = torch.matmul(probs, feats)
+ ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
+ return ocr_context
+
+
+class ObjectAttentionBlock(_SelfAttentionBlock):
+ """Make a OCR used SelfAttentionBlock."""
+
+ def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
+ act_cfg):
+ if scale > 1:
+ query_downsample = nn.MaxPool2d(kernel_size=scale)
+ else:
+ query_downsample = None
+ super(ObjectAttentionBlock, self).__init__(
+ key_in_channels=in_channels,
+ query_in_channels=in_channels,
+ channels=channels,
+ out_channels=in_channels,
+ share_key_query=False,
+ query_downsample=query_downsample,
+ key_downsample=None,
+ key_query_num_convs=2,
+ key_query_norm=True,
+ value_out_num_convs=1,
+ value_out_norm=True,
+ matmul_norm=True,
+ with_out=True,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.bottleneck = ConvModule(
+ in_channels * 2,
+ in_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, query_feats, key_feats):
+ """Forward function."""
+ context = super(ObjectAttentionBlock,
+ self).forward(query_feats, key_feats)
+ output = self.bottleneck(torch.cat([context, query_feats], dim=1))
+ if self.query_downsample is not None:
+ output = resize(query_feats)
+
+ return output
+
+
+@HEADS.register_module()
+class OCRHead(BaseCascadeDecodeHead):
+ """Object-Contextual Representations for Semantic Segmentation.
+
+ This head is the implementation of `OCRNet
+ `_.
+
+ Args:
+ ocr_channels (int): The intermediate channels of OCR block.
+ scale (int): The scale of probability map in SpatialGatherModule in
+ Default: 1.
+ """
+
+ def __init__(self, ocr_channels, scale=1, **kwargs):
+ super(OCRHead, self).__init__(**kwargs)
+ self.ocr_channels = ocr_channels
+ self.scale = scale
+ self.object_context_block = ObjectAttentionBlock(
+ self.channels,
+ self.ocr_channels,
+ self.scale,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.spatial_gather_module = SpatialGatherModule(self.scale)
+
+ self.bottleneck = ConvModule(
+ self.in_channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs, prev_output):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ feats = self.bottleneck(x)
+ context = self.spatial_gather_module(feats, prev_output)
+ object_context = self.object_context_block(feats, context)
+ output = self.cls_seg(object_context)
+
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/point_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6782763e30386d99115977ebe5a4d9291bae8d9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/point_head.py
@@ -0,0 +1,354 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
+
+import torch
+import torch.nn as nn
+
+try:
+ from mmcv.cnn import ConvModule, normal_init
+ from mmcv.ops import point_sample
+except ImportError:
+ from annotator.mmpkg.mmcv.cnn import ConvModule, normal_init
+ from annotator.mmpkg.mmcv.ops import point_sample
+
+from annotator.mmpkg.mmseg.models.builder import HEADS
+from annotator.mmpkg.mmseg.ops import resize
+from ..losses import accuracy
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+def calculate_uncertainty(seg_logits):
+ """Estimate uncertainty based on seg logits.
+
+ For each location of the prediction ``seg_logits`` we estimate
+ uncertainty as the difference between top first and top second
+ predicted logits.
+
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits,
+ shape (batch_size, num_classes, height, width).
+
+ Returns:
+ scores (Tensor): T uncertainty scores with the most uncertain
+ locations having the highest uncertainty score, shape (
+ batch_size, 1, height, width)
+ """
+ top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
+ return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
+
+
+@HEADS.register_module()
+class PointHead(BaseCascadeDecodeHead):
+ """A mask point head use in PointRend.
+
+ ``PointHead`` use shared multi-layer perceptron (equivalent to
+ nn.Conv1d) to predict the logit of input points. The fine-grained feature
+ and coarse feature will be concatenate together for predication.
+
+ Args:
+ num_fcs (int): Number of fc layers in the head. Default: 3.
+ in_channels (int): Number of input channels. Default: 256.
+ fc_channels (int): Number of fc channels. Default: 256.
+ num_classes (int): Number of classes for logits. Default: 80.
+ class_agnostic (bool): Whether use class agnostic classification.
+ If so, the output channels of logits will be 1. Default: False.
+ coarse_pred_each_layer (bool): Whether concatenate coarse feature with
+ the output of each fc layer. Default: True.
+ conv_cfg (dict|None): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d'))
+ norm_cfg (dict|None): Dictionary to construct and config norm layer.
+ Default: None.
+ loss_point (dict): Dictionary to construct and config loss layer of
+ point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+ loss_weight=1.0).
+ """
+
+ def __init__(self,
+ num_fcs=3,
+ coarse_pred_each_layer=True,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU', inplace=False),
+ **kwargs):
+ super(PointHead, self).__init__(
+ input_transform='multiple_select',
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **kwargs)
+
+ self.num_fcs = num_fcs
+ self.coarse_pred_each_layer = coarse_pred_each_layer
+
+ fc_in_channels = sum(self.in_channels) + self.num_classes
+ fc_channels = self.channels
+ self.fcs = nn.ModuleList()
+ for k in range(num_fcs):
+ fc = ConvModule(
+ fc_in_channels,
+ fc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.fcs.append(fc)
+ fc_in_channels = fc_channels
+ fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
+ else 0
+ self.fc_seg = nn.Conv1d(
+ fc_in_channels,
+ self.num_classes,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ if self.dropout_ratio > 0:
+ self.dropout = nn.Dropout(self.dropout_ratio)
+ delattr(self, 'conv_seg')
+
+ def init_weights(self):
+ """Initialize weights of classification layer."""
+ normal_init(self.fc_seg, std=0.001)
+
+ def cls_seg(self, feat):
+ """Classify each pixel with fc."""
+ if self.dropout is not None:
+ feat = self.dropout(feat)
+ output = self.fc_seg(feat)
+ return output
+
+ def forward(self, fine_grained_point_feats, coarse_point_feats):
+ x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
+ for fc in self.fcs:
+ x = fc(x)
+ if self.coarse_pred_each_layer:
+ x = torch.cat((x, coarse_point_feats), dim=1)
+ return self.cls_seg(x)
+
+ def _get_fine_grained_point_feats(self, x, points):
+ """Sample from fine grained features.
+
+ Args:
+ x (list[Tensor]): Feature pyramid from by neck or backbone.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+
+ Returns:
+ fine_grained_feats (Tensor): Sampled fine grained feature,
+ shape (batch_size, sum(channels of x), num_points).
+ """
+
+ fine_grained_feats_list = [
+ point_sample(_, points, align_corners=self.align_corners)
+ for _ in x
+ ]
+ if len(fine_grained_feats_list) > 1:
+ fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
+ else:
+ fine_grained_feats = fine_grained_feats_list[0]
+
+ return fine_grained_feats
+
+ def _get_coarse_point_feats(self, prev_output, points):
+ """Sample from fine grained features.
+
+ Args:
+ prev_output (list[Tensor]): Prediction of previous decode head.
+ points (Tensor): Point coordinates, shape (batch_size,
+ num_points, 2).
+
+ Returns:
+ coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
+ num_classes, num_points).
+ """
+
+ coarse_feats = point_sample(
+ prev_output, points, align_corners=self.align_corners)
+
+ return coarse_feats
+
+ def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+ train_cfg):
+ """Forward function for training.
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+ train_cfg (dict): The training config.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ x = self._transform_inputs(inputs)
+ with torch.no_grad():
+ points = self.get_points_train(
+ prev_output, calculate_uncertainty, cfg=train_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+ point_label = point_sample(
+ gt_semantic_seg.float(),
+ points,
+ mode='nearest',
+ align_corners=self.align_corners)
+ point_label = point_label.squeeze(1).long()
+
+ losses = self.losses(point_logits, point_label)
+
+ return losses
+
+ def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+ """Forward function for testing.
+
+ Args:
+ inputs (list[Tensor]): List of multi-level img features.
+ prev_output (Tensor): The output of previous decode head.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ test_cfg (dict): The testing config.
+
+ Returns:
+ Tensor: Output segmentation map.
+ """
+
+ x = self._transform_inputs(inputs)
+ refined_seg_logits = prev_output.clone()
+ for _ in range(test_cfg.subdivision_steps):
+ refined_seg_logits = resize(
+ refined_seg_logits,
+ scale_factor=test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ batch_size, channels, height, width = refined_seg_logits.shape
+ point_indices, points = self.get_points_test(
+ refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, points)
+ coarse_point_feats = self._get_coarse_point_feats(
+ prev_output, points)
+ point_logits = self.forward(fine_grained_point_feats,
+ coarse_point_feats)
+
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_seg_logits = refined_seg_logits.reshape(
+ batch_size, channels, height * width)
+ refined_seg_logits = refined_seg_logits.scatter_(
+ 2, point_indices, point_logits)
+ refined_seg_logits = refined_seg_logits.view(
+ batch_size, channels, height, width)
+
+ return refined_seg_logits
+
+ def losses(self, point_logits, point_label):
+ """Compute segmentation loss."""
+ loss = dict()
+ loss['loss_point'] = self.loss_decode(
+ point_logits, point_label, ignore_index=self.ignore_index)
+ loss['acc_point'] = accuracy(point_logits, point_label)
+ return loss
+
+ def get_points_train(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for training.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+ uncertainty. The uncertainties are calculated for each point using
+ 'uncertainty_func' function that takes point's logit prediction as
+ input.
+
+ Args:
+ seg_logits (Tensor): Semantic segmentation logits, shape (
+ batch_size, num_classes, height, width).
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Training config of point head.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains the coordinates of ``num_points`` sampled
+ points.
+ """
+ num_points = cfg.num_points
+ oversample_ratio = cfg.oversample_ratio
+ importance_sample_ratio = cfg.importance_sample_ratio
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = seg_logits.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=seg_logits.device)
+ point_logits = point_sample(seg_logits, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+ # and sampled point will get -1 uncertainty.
+ point_uncertainties = uncertainty_func(point_logits)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=seg_logits.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_point_coords = torch.rand(
+ batch_size, num_random_points, 2, device=seg_logits.device)
+ point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
+ return point_coords
+
+ def get_points_test(self, seg_logits, uncertainty_func, cfg):
+ """Sample points for testing.
+
+ Find ``num_points`` most uncertain points from ``uncertainty_map``.
+
+ Args:
+ seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
+ height, width) for class-specific or class-agnostic prediction.
+ uncertainty_func (func): uncertainty calculation function.
+ cfg (dict): Testing config of point head.
+
+ Returns:
+ point_indices (Tensor): A tensor of shape (batch_size, num_points)
+ that contains indices from [0, height x width) of the most
+ uncertain points.
+ point_coords (Tensor): A tensor of shape (batch_size, num_points,
+ 2) that contains [0, 1] x [0, 1] normalized coordinates of the
+ most uncertain points from the ``height x width`` grid .
+ """
+
+ num_points = cfg.subdivision_num_points
+ uncertainty_map = uncertainty_func(seg_logits)
+ batch_size, _, height, width = uncertainty_map.shape
+ h_step = 1.0 / height
+ w_step = 1.0 / width
+
+ uncertainty_map = uncertainty_map.view(batch_size, height * width)
+ num_points = min(height * width, num_points)
+ point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+ point_coords = torch.zeros(
+ batch_size,
+ num_points,
+ 2,
+ dtype=torch.float,
+ device=seg_logits.device)
+ point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
+ width).float() * w_step
+ point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
+ width).float() * h_step
+ return point_indices, point_coords
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psa_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba6fe3a8b8f8dc7c4e4d3b9bc09e9642c0b3732f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psa_head.py
@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+try:
+ try:
+ from mmcv.ops import PSAMask
+ except ImportError:
+ from annotator.mmpkg.mmcv.ops import PSAMask
+except ModuleNotFoundError:
+ PSAMask = None
+
+
+@HEADS.register_module()
+class PSAHead(BaseDecodeHead):
+ """Point-wise Spatial Attention Network for Scene Parsing.
+
+ This head is the implementation of `PSANet
+ `_.
+
+ Args:
+ mask_size (tuple[int]): The PSA mask size. It usually equals input
+ size.
+ psa_type (str): The type of psa module. Options are 'collect',
+ 'distribute', 'bi-direction'. Default: 'bi-direction'
+ compact (bool): Whether use compact map for 'collect' mode.
+ Default: True.
+ shrink_factor (int): The downsample factors of psa mask. Default: 2.
+ normalization_factor (float): The normalize factor of attention.
+ psa_softmax (bool): Whether use softmax for attention.
+ """
+
+ def __init__(self,
+ mask_size,
+ psa_type='bi-direction',
+ compact=False,
+ shrink_factor=2,
+ normalization_factor=1.0,
+ psa_softmax=True,
+ **kwargs):
+ if PSAMask is None:
+ raise RuntimeError('Please install mmcv-full for PSAMask ops')
+ super(PSAHead, self).__init__(**kwargs)
+ assert psa_type in ['collect', 'distribute', 'bi-direction']
+ self.psa_type = psa_type
+ self.compact = compact
+ self.shrink_factor = shrink_factor
+ self.mask_size = mask_size
+ mask_h, mask_w = mask_size
+ self.psa_softmax = psa_softmax
+ if normalization_factor is None:
+ normalization_factor = mask_h * mask_w
+ self.normalization_factor = normalization_factor
+
+ self.reduce = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ if psa_type == 'bi-direction':
+ self.reduce_p = ConvModule(
+ self.in_channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.attention_p = nn.Sequential(
+ ConvModule(
+ self.channels,
+ self.channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ nn.Conv2d(
+ self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+ self.psamask_collect = PSAMask('collect', mask_size)
+ self.psamask_distribute = PSAMask('distribute', mask_size)
+ else:
+ self.psamask = PSAMask(psa_type, mask_size)
+ self.proj = ConvModule(
+ self.channels * (2 if psa_type == 'bi-direction' else 1),
+ self.in_channels,
+ kernel_size=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ self.bottleneck = ConvModule(
+ self.in_channels * 2,
+ self.channels,
+ kernel_size=3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ identity = x
+ align_corners = self.align_corners
+ if self.psa_type in ['collect', 'distribute']:
+ out = self.reduce(x)
+ n, c, h, w = out.size()
+ if self.shrink_factor != 1:
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ out = resize(
+ out,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y = self.attention(out)
+ if self.compact:
+ if self.psa_type == 'collect':
+ y = y.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y = self.psamask(y)
+ if self.psa_softmax:
+ y = F.softmax(y, dim=1)
+ out = torch.bmm(
+ out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ else:
+ x_col = self.reduce(x)
+ x_dis = self.reduce_p(x)
+ n, c, h, w = x_col.size()
+ if self.shrink_factor != 1:
+ if h % self.shrink_factor and w % self.shrink_factor:
+ h = (h - 1) // self.shrink_factor + 1
+ w = (w - 1) // self.shrink_factor + 1
+ align_corners = True
+ else:
+ h = h // self.shrink_factor
+ w = w // self.shrink_factor
+ align_corners = False
+ x_col = resize(
+ x_col,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ x_dis = resize(
+ x_dis,
+ size=(h, w),
+ mode='bilinear',
+ align_corners=align_corners)
+ y_col = self.attention(x_col)
+ y_dis = self.attention_p(x_dis)
+ if self.compact:
+ y_dis = y_dis.view(n, h * w,
+ h * w).transpose(1, 2).view(n, h * w, h, w)
+ else:
+ y_col = self.psamask_collect(y_col)
+ y_dis = self.psamask_distribute(y_dis)
+ if self.psa_softmax:
+ y_col = F.softmax(y_col, dim=1)
+ y_dis = F.softmax(y_dis, dim=1)
+ x_col = torch.bmm(
+ x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ x_dis = torch.bmm(
+ x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
+ n, c, h, w) * (1.0 / self.normalization_factor)
+ out = torch.cat([x_col, x_dis], 1)
+ out = self.proj(out)
+ out = resize(
+ out,
+ size=identity.shape[2:],
+ mode='bilinear',
+ align_corners=align_corners)
+ out = self.bottleneck(torch.cat((identity, out), dim=1))
+ out = self.cls_seg(out)
+ return out
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psp_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a88d807bfe11fe224305f8de745cde3aa739db0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class PPM(nn.ModuleList):
+ """Pooling Pyramid Module used in PSPNet.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module.
+ in_channels (int): Input channels.
+ channels (int): Channels after modules, before conv_seg.
+ conv_cfg (dict|None): Config of conv layers.
+ norm_cfg (dict|None): Config of norm layers.
+ act_cfg (dict): Config of activation layers.
+ align_corners (bool): align_corners argument of F.interpolate.
+ """
+
+ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
+ act_cfg, align_corners):
+ super(PPM, self).__init__()
+ self.pool_scales = pool_scales
+ self.align_corners = align_corners
+ self.in_channels = in_channels
+ self.channels = channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ for pool_scale in pool_scales:
+ self.append(
+ nn.Sequential(
+ nn.AdaptiveAvgPool2d(pool_scale),
+ ConvModule(
+ self.in_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)))
+
+ def forward(self, x):
+ """Forward function."""
+ ppm_outs = []
+ for ppm in self:
+ ppm_out = ppm(x)
+ upsampled_ppm_out = resize(
+ ppm_out,
+ size=x.size()[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ ppm_outs.append(upsampled_ppm_out)
+ return ppm_outs
+
+
+@HEADS.register_module()
+class PSPHead(BaseDecodeHead):
+ """Pyramid Scene Parsing Network.
+
+ This head is the implementation of
+ `PSPNet `_.
+
+ Args:
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+ Module. Default: (1, 2, 3, 6).
+ """
+
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+ super(PSPHead, self).__init__(**kwargs)
+ assert isinstance(pool_scales, (list, tuple))
+ self.pool_scales = pool_scales
+ self.psp_modules = PPM(
+ self.pool_scales,
+ self.in_channels,
+ self.channels,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ align_corners=self.align_corners)
+ self.bottleneck = ConvModule(
+ self.in_channels + len(pool_scales) * self.channels,
+ self.channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ x = self._transform_inputs(inputs)
+ psp_outs = [x]
+ psp_outs.extend(self.psp_modules(x))
+ psp_outs = torch.cat(psp_outs, dim=1)
+ output = self.bottleneck(psp_outs)
+ output = self.cls_seg(output)
+ return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_aspp_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a23970699df7afd86f483316be3c8d1a34d43c18
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_aspp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .aspp_head import ASPPHead, ASPPModule
+
+
+class DepthwiseSeparableASPPModule(ASPPModule):
+    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
+    conv.
+
+    After the parent ``ASPPModule`` has built its branches, every branch
+    whose dilation is greater than 1 is replaced in place by a
+    ``DepthwiseSeparableConvModule`` using that dilation (with
+    ``padding=dilation``). Branches with dilation <= 1 are left as built by
+    the parent class.
+
+    Args:
+        **kwargs: Forwarded unchanged to ``ASPPModule.__init__``.
+    """
+
+    def __init__(self, **kwargs):
+        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
+        for i, dilation in enumerate(self.dilations):
+            if dilation > 1:
+                # Swap the i-th branch for a depthwise separable conv.
+                self[i] = DepthwiseSeparableConvModule(
+                    self.in_channels,
+                    self.channels,
+                    3,
+                    dilation=dilation,
+                    padding=dilation,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg)
+
+
+@HEADS.register_module()
+class DepthwiseSeparableASPPHead(ASPPHead):
+    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+    Segmentation.
+
+    This head is the implementation of `DeepLabV3+
+    <https://arxiv.org/abs/1802.02611>`_.
+
+    Args:
+        c1_in_channels (int): The input channels of c1 decoder. If it is 0,
+            no c1 decoder is built and ``inputs[0]`` is not used.
+        c1_channels (int): The intermediate channels of c1 decoder.
+        **kwargs: Forwarded unchanged to ``ASPPHead.__init__``.
+    """
+
+    def __init__(self, c1_in_channels, c1_channels, **kwargs):
+        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
+        assert c1_in_channels >= 0
+        # Replace the ASPP module built by the parent class with the
+        # depthwise separable variant.
+        self.aspp_modules = DepthwiseSeparableASPPModule(
+            dilations=self.dilations,
+            in_channels=self.in_channels,
+            channels=self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        if c1_in_channels > 0:
+            self.c1_bottleneck = ConvModule(
+                c1_in_channels,
+                c1_channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+        else:
+            self.c1_bottleneck = None
+        # NOTE(review): the first conv always expects
+        # ``self.channels + c1_channels`` input channels, even when
+        # ``c1_bottleneck`` is None and nothing is concatenated in
+        # ``forward`` -- confirm callers pass c1_channels=0 in that case.
+        self.sep_bottleneck = nn.Sequential(
+            DepthwiseSeparableConvModule(
+                self.channels + c1_channels,
+                self.channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg),
+            DepthwiseSeparableConvModule(
+                self.channels,
+                self.channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg))
+
+    def forward(self, inputs):
+        """Forward function.
+
+        Args:
+            inputs: Backbone features. ``self._transform_inputs`` selects
+                the ASPP input; ``inputs[0]`` feeds the c1 decoder when
+                one was built.
+
+        Returns:
+            The result of ``self.cls_seg`` applied to the decoded features.
+        """
+        x = self._transform_inputs(inputs)
+        # Image-level pooling branch, resized back to the size of ``x``.
+        aspp_outs = [
+            resize(
+                self.image_pool(x),
+                size=x.size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        ]
+        aspp_outs.extend(self.aspp_modules(x))
+        aspp_outs = torch.cat(aspp_outs, dim=1)
+        output = self.bottleneck(aspp_outs)
+        if self.c1_bottleneck is not None:
+            c1_output = self.c1_bottleneck(inputs[0])
+            # Upsample the ASPP output to the c1 resolution before
+            # concatenating along channels.
+            output = resize(
+                input=output,
+                size=c1_output.shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            output = torch.cat([output, c1_output], dim=1)
+        output = self.sep_bottleneck(output)
+        output = self.cls_seg(output)
+        return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_fcn_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ea198ab8a96919dfb6974fd73b1476aa488aef2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_fcn_head.py
@@ -0,0 +1,51 @@
+from annotator.mmpkg.mmcv.cnn import DepthwiseSeparableConvModule
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class DepthwiseSeparableFCNHead(FCNHead):
+    """Depthwise-Separable Fully Convolutional Network for Semantic
+    Segmentation.
+
+    This head is implemented according to Fast-SCNN paper. It builds a
+    regular ``FCNHead`` and then overwrites each entry of ``self.convs``
+    (and ``self.conv_cat`` when ``concat_input`` is set) with a
+    ``DepthwiseSeparableConvModule``.
+
+    Args:
+        in_channels(int): Number of output channels of FFM.
+        channels(int): Number of middle-stage channels in the decode head.
+        concat_input(bool): Whether to concatenate original decode input into
+            the result of several consecutive convolution layers.
+            Default: True.
+        num_classes(int): Used to determine the dimension of
+            final prediction tensor.
+        in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
+        norm_cfg (dict | None): Config of norm layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        loss_decode(dict): Config of loss type and some
+            relevant additional options.
+    """
+
+    def __init__(self, **kwargs):
+        super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
+        # NOTE(review): indexing ``self.convs[0]`` assumes the parent built
+        # at least one conv (num_convs >= 1) -- confirm against FCNHead.
+        # Only norm_cfg is forwarded below; act_cfg is left at whatever
+        # DepthwiseSeparableConvModule defaults to.
+        self.convs[0] = DepthwiseSeparableConvModule(
+            self.in_channels,
+            self.channels,
+            kernel_size=self.kernel_size,
+            padding=self.kernel_size // 2,
+            norm_cfg=self.norm_cfg)
+        for i in range(1, self.num_convs):
+            self.convs[i] = DepthwiseSeparableConvModule(
+                self.channels,
+                self.channels,
+                kernel_size=self.kernel_size,
+                padding=self.kernel_size // 2,
+                norm_cfg=self.norm_cfg)
+
+        if self.concat_input:
+            # Fuses the concatenation of the head input and the conv output.
+            self.conv_cat = DepthwiseSeparableConvModule(
+                self.in_channels + self.channels,
+                self.channels,
+                kernel_size=self.kernel_size,
+                padding=self.kernel_size // 2,
+                norm_cfg=self.norm_cfg)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/uper_head.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/uper_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..952473578c1f5b903f5fc7f9d13a4e40ea5dec87
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/decode_heads/uper_head.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from annotator.mmpkg.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+from .psp_head import PPM
+
+
+@HEADS.register_module()
+class UPerHead(BaseDecodeHead):
+    """Unified Perceptual Parsing for Scene Understanding.
+
+    This head is the implementation of `UPerNet
+    <https://arxiv.org/abs/1807.10221>`_.
+
+    The last input level goes through a Pooling Pyramid Module; all other
+    levels go through lateral convs. A top-down pass then adds each level
+    into the one below it, and all levels are resized to the size of
+    level 0 and fused by ``self.fpn_bottleneck``.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module applied on the last feature. Default: (1, 2, 3, 6).
+        **kwargs: Forwarded to ``BaseDecodeHead.__init__``;
+            ``input_transform`` is always set to ``'multiple_select'``.
+    """
+
+    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+        super(UPerHead, self).__init__(
+            input_transform='multiple_select', **kwargs)
+        # PSP Module
+        self.psp_modules = PPM(
+            pool_scales,
+            self.in_channels[-1],
+            self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            align_corners=self.align_corners)
+        # Fuses the last feature map with one map per pyramid branch.
+        self.bottleneck = ConvModule(
+            self.in_channels[-1] + len(pool_scales) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        # FPN Module
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        for in_channels in self.in_channels[:-1]:  # skip the top layer
+            l_conv = ConvModule(
+                in_channels,
+                self.channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg,
+                inplace=False)
+            fpn_conv = ConvModule(
+                self.channels,
+                self.channels,
+                3,
+                padding=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg,
+                inplace=False)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        # One ``self.channels``-wide map per input level is concatenated
+        # in ``forward`` before this conv.
+        self.fpn_bottleneck = ConvModule(
+            len(self.in_channels) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def psp_forward(self, inputs):
+        """Forward function of PSP module.
+
+        Uses only ``inputs[-1]``: concatenates it with the pyramid branch
+        outputs along channels and returns ``self.bottleneck`` of that.
+        """
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def forward(self, inputs):
+        """Forward function.
+
+        Args:
+            inputs: Backbone features, passed through
+                ``self._transform_inputs`` before use.
+
+        Returns:
+            The result of ``self.cls_seg`` applied to the fused features.
+        """
+
+        inputs = self._transform_inputs(inputs)
+
+        # build laterals
+        laterals = [
+            lateral_conv(inputs[i])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+
+        # the top level comes from the PSP branch instead of a lateral conv
+        laterals.append(self.psp_forward(inputs))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = laterals[i - 1].shape[2:]
+            # in-place accumulation into the next lower level
+            laterals[i - 1] += resize(
+                laterals[i],
+                size=prev_shape,
+                mode='bilinear',
+                align_corners=self.align_corners)
+
+        # build outputs
+        fpn_outs = [
+            self.fpn_convs[i](laterals[i])
+            for i in range(used_backbone_levels - 1)
+        ]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        # resize every level except level 0 to the size of level 0
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = resize(
+                fpn_outs[i],
+                size=fpn_outs[0].shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        fpn_outs = torch.cat(fpn_outs, dim=1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.cls_seg(output)
+        return output
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..beca72045694273d63465bac2f27dbc6672271db
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/__init__.py
@@ -0,0 +1,12 @@
+from .accuracy import Accuracy, accuracy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+ cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .lovasz_loss import LovaszLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+
+__all__ = [
+ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+ 'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+ 'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/accuracy.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fd2e7e74a0f721c4a814c09d6e453e5956bb38
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/accuracy.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+ pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+ target (torch.Tensor): The target of each prediction, shape (N, , ...)
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+ this threshold are considered incorrect. Default to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.size(0) == 0:
+ accu = [pred.new_tensor(0.) for i in range(len(topk))]
+ return accu[0] if return_single else accu
+ assert pred.ndim == target.ndim + 1
+ assert pred.size(0) == target.size(0)
+ assert maxk <= pred.size(1), \
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+ pred_value, pred_label = pred.topk(maxk, dim=1)
+ # transpose to shape (maxk, N, ...)
+ pred_label = pred_label.transpose(0, 1)
+ correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+ if thresh is not None:
+ # Only prediction values larger than thresh are counted as correct
+ correct = correct & (pred_value > thresh).t()
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / target.numel()))
+ return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+ """Accuracy calculation module."""
+
+ def __init__(self, topk=(1, ), thresh=None):
+ """Module to calculate the accuracy.
+
+ Args:
+ topk (tuple, optional): The criterion used to calculate the
+ accuracy. Defaults to (1,).
+ thresh (float, optional): If not None, predictions with scores
+ under this threshold are considered incorrect. Default to None.
+ """
+ super().__init__()
+ self.topk = topk
+ self.thresh = thresh
+
+ def forward(self, pred, target):
+ """Forward function to calculate accuracy.
+
+ Args:
+ pred (torch.Tensor): Prediction of models.
+ target (torch.Tensor): Target for each prediction.
+
+ Returns:
+ tuple[float]: The accuracies under different topk criterions.
+ """
+ return accuracy(pred, target, self.topk, self.thresh)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/cross_entropy_loss.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c0790c98616bb69621deed55547fc04c7392ef
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/cross_entropy_loss.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=-100):
+ """The wrapper function for :func:`F.cross_entropy`"""
+ # class_weight is a manual rescaling weight given to each class.
+ # If given, has to be a Tensor of size C element-wise losses
+ loss = F.cross_entropy(
+ pred,
+ label,
+ weight=class_weight,
+ reduction='none',
+ ignore_index=ignore_index)
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+ """Expand onehot labels to match the size of prediction."""
+ bin_labels = labels.new_zeros(target_shape)
+ valid_mask = (labels >= 0) & (labels != ignore_index)
+ inds = torch.nonzero(valid_mask, as_tuple=True)
+
+ if inds[0].numel() > 0:
+ if labels.dim() == 3:
+ bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+ else:
+ bin_labels[inds[0], labels[valid_mask]] = 1
+
+ valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+ if label_weights is None:
+ bin_label_weights = valid_mask
+ else:
+ bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+ bin_label_weights *= valid_mask
+
+ return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=255):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (int | None): The label index to be ignored. Default: 255
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if pred.dim() != label.dim():
+ assert (pred.dim() == 2 and label.dim() == 1) or (
+ pred.dim() == 4 and label.dim() == 3), \
+ 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+ 'H, W], label shape [N, H, W] are supported'
+ label, weight = _expand_onehot_labels(label, weight, pred.shape,
+ ignore_index)
+
+ # weighted element-wise losses
+ if weight is not None:
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, label.float(), pos_weight=class_weight, reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None,
+ ignore_index=None):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ label (torch.Tensor): ``label`` indicates the class label of the mask'
+ corresponding object. This will be used to select the mask in the
+ of the class which the object belongs to when the mask prediction
+ if not class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+ ignore_index (None): Placeholder, to be consistent with other loss.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert ignore_index is None, 'BCE loss does not support ignore_index'
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+ """CrossEntropyLoss.
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ of softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+ reduction (str, optional): . Defaults to 'mean'.
+ Options are "none", "mean" and "sum".
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = get_class_weight(class_weight)
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function."""
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/dice_loss.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a77b962d7d8b3079c7d6cd9db52280c6fb4970
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/dice_loss.py
@@ -0,0 +1,119 @@
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
+segmentron/solver/loss.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weighted_loss
+
+
+@weighted_loss
+def dice_loss(pred,
+ target,
+ valid_mask,
+ smooth=1,
+ exponent=2,
+ class_weight=None,
+ ignore_index=255):
+ assert pred.shape[0] == target.shape[0]
+ total_loss = 0
+ num_classes = pred.shape[1]
+ for i in range(num_classes):
+ if i != ignore_index:
+ dice_loss = binary_dice_loss(
+ pred[:, i],
+ target[..., i],
+ valid_mask=valid_mask,
+ smooth=smooth,
+ exponent=exponent)
+ if class_weight is not None:
+ dice_loss *= class_weight[i]
+ total_loss += dice_loss
+ return total_loss / num_classes
+
+
+@weighted_loss
+def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
+ assert pred.shape[0] == target.shape[0]
+ pred = pred.reshape(pred.shape[0], -1)
+ target = target.reshape(target.shape[0], -1)
+ valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
+
+ num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
+ den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
+
+ return 1 - num / den
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+ """DiceLoss.
+
+ This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
+ Volumetric Medical Image Segmentation `_.
+
+ Args:
+ loss_type (str, optional): Binary or multi-class loss.
+ Default: 'multi_class'. Options are "binary" and "multi_class".
+ smooth (float): A float number to smooth loss, and avoid NaN error.
+ Default: 1
+ exponent (float): An float number to calculate denominator
+ value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ class_weight (list[float] | str, optional): Weight of each class. If in
+ str format, read them from a file. Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Default to 1.0.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+ """
+
+ def __init__(self,
+ smooth=1,
+ exponent=2,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0,
+ ignore_index=255,
+ **kwards):
+ super(DiceLoss, self).__init__()
+ self.smooth = smooth
+ self.exponent = exponent
+ self.reduction = reduction
+ self.class_weight = get_class_weight(class_weight)
+ self.loss_weight = loss_weight
+ self.ignore_index = ignore_index
+
+ def forward(self,
+ pred,
+ target,
+ avg_factor=None,
+ reduction_override=None,
+ **kwards):
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = pred.new_tensor(self.class_weight)
+ else:
+ class_weight = None
+
+ pred = F.softmax(pred, dim=1)
+ num_classes = pred.shape[1]
+ one_hot_target = F.one_hot(
+ torch.clamp(target.long(), 0, num_classes - 1),
+ num_classes=num_classes)
+ valid_mask = (target != self.ignore_index).long()
+
+ loss = self.loss_weight * dice_loss(
+ pred,
+ one_hot_target,
+ valid_mask=valid_mask,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ smooth=self.smooth,
+ exponent=self.exponent,
+ class_weight=class_weight,
+ ignore_index=self.ignore_index)
+ return loss
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/lovasz_loss.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/lovasz_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..50f0f70fd432316b081a0114c28df61d320b5a47
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/lovasz_loss.py
@@ -0,0 +1,303 @@
+"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
+ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
+Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
+
+import annotator.mmpkg.mmcv as mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def lovasz_grad(gt_sorted):
+ """Computes gradient of the Lovasz extension w.r.t sorted errors.
+
+ See Alg. 1 in paper.
+ """
+ p = len(gt_sorted)
+ gts = gt_sorted.sum()
+ intersection = gts - gt_sorted.float().cumsum(0)
+ union = gts + (1 - gt_sorted).float().cumsum(0)
+ jaccard = 1. - intersection / union
+ if p > 1: # cover 1-pixel case
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+ return jaccard
+
+
+def flatten_binary_logits(logits, labels, ignore_index=None):
+ """Flattens predictions in the batch (binary case) Remove labels equal to
+ 'ignore_index'."""
+ logits = logits.view(-1)
+ labels = labels.view(-1)
+ if ignore_index is None:
+ return logits, labels
+ valid = (labels != ignore_index)
+ vlogits = logits[valid]
+ vlabels = labels[valid]
+ return vlogits, vlabels
+
+
+def flatten_probs(probs, labels, ignore_index=None):
+ """Flattens predictions in the batch."""
+ if probs.dim() == 3:
+ # assumes output of a sigmoid layer
+ B, H, W = probs.size()
+ probs = probs.view(B, 1, H, W)
+ B, C, H, W = probs.size()
+ probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C) # B*H*W, C=P,C
+ labels = labels.view(-1)
+ if ignore_index is None:
+ return probs, labels
+ valid = (labels != ignore_index)
+ vprobs = probs[valid.nonzero().squeeze()]
+ vlabels = labels[valid]
+ return vprobs, vlabels
+
+
+def lovasz_hinge_flat(logits, labels):
+ """Binary Lovasz hinge loss.
+
+ Args:
+ logits (torch.Tensor): [P], logits at each prediction
+ (between -infty and +infty).
+ labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if len(labels) == 0:
+ # only void pixels, the gradients should be 0
+ return logits.sum() * 0.
+ signs = 2. * labels.float() - 1.
+ errors = (1. - logits * signs)
+ errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+ perm = perm.data
+ gt_sorted = labels[perm]
+ grad = lovasz_grad(gt_sorted)
+ loss = torch.dot(F.relu(errors_sorted), grad)
+ return loss
+
+
+def lovasz_hinge(logits,
+ labels,
+ classes='present',
+ per_image=False,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=255):
+ """Binary Lovasz hinge loss.
+
+ Args:
+ logits (torch.Tensor): [B, H, W], logits at each pixel
+ (between -infty and +infty).
+ labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
+ classes (str | list[int], optional): Placeholder, to be consistent with
+ other loss. Default: None.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ class_weight (list[float], optional): Placeholder, to be consistent
+ with other loss. Default: None.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. This parameter only works when per_image is True.
+ Default: None.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if per_image:
+ loss = [
+ lovasz_hinge_flat(*flatten_binary_logits(
+ logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
+ for logit, label in zip(logits, labels)
+ ]
+ loss = weight_reduce_loss(
+ torch.stack(loss), None, reduction, avg_factor)
+ else:
+ loss = lovasz_hinge_flat(
+ *flatten_binary_logits(logits, labels, ignore_index))
+ return loss
+
+
+def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
+ """Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probs (torch.Tensor): [P, C], class probabilities at each prediction
+ (between 0 and 1).
+ labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+ if probs.numel() == 0:
+ # only void pixels, the gradients should be 0
+ return probs * 0.
+ C = probs.size(1)
+ losses = []
+ class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+ for c in class_to_sum:
+ fg = (labels == c).float() # foreground for class c
+ if (classes == 'present' and fg.sum() == 0):
+ continue
+ if C == 1:
+ if len(classes) > 1:
+ raise ValueError('Sigmoid output possible only with 1 class')
+ class_pred = probs[:, 0]
+ else:
+ class_pred = probs[:, c]
+ errors = (fg - class_pred).abs()
+ errors_sorted, perm = torch.sort(errors, 0, descending=True)
+ perm = perm.data
+ fg_sorted = fg[perm]
+ loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
+ if class_weight is not None:
+ loss *= class_weight[c]
+ losses.append(loss)
+ return torch.stack(losses).mean()
+
+
+def lovasz_softmax(probs,
+ labels,
+ classes='present',
+ per_image=False,
+ class_weight=None,
+ reduction='mean',
+ avg_factor=None,
+ ignore_index=255):
+ """Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probs (torch.Tensor): [B, C, H, W], class probabilities at each
+ prediction (between 0 and 1).
+ labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
+ C - 1).
+ classes (str | list[int], optional): Classes chosen to calculate loss.
+ 'all' for all classes, 'present' for classes present in labels, or
+ a list of classes to average. Default: 'present'.
+ per_image (bool, optional): If per_image is True, compute the loss per
+ image instead of per batch. Default: False.
+ class_weight (list[float], optional): The weight for each class.
+ Default: None.
+ reduction (str, optional): The method used to reduce the loss. Options
+ are "none", "mean" and "sum". This parameter only works when
+ per_image is True. Default: 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. This parameter only works when per_image is True.
+ Default: None.
+ ignore_index (int | None): The label index to be ignored. Default: 255.
+
+ Returns:
+ torch.Tensor: The calculated loss.
+ """
+
+ if per_image:
+ loss = [
+ lovasz_softmax_flat(
+ *flatten_probs(
+ prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
+ classes=classes,
+ class_weight=class_weight)
+ for prob, label in zip(probs, labels)
+ ]
+ loss = weight_reduce_loss(
+ torch.stack(loss), None, reduction, avg_factor)
+ else:
+ loss = lovasz_softmax_flat(
+ *flatten_probs(probs, labels, ignore_index),
+ classes=classes,
+ class_weight=class_weight)
+ return loss
+
+
+@LOSSES.register_module()
+class LovaszLoss(nn.Module):
+    """LovaszLoss.
+
+    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
+    for the optimization of the intersection-over-union measure in neural
+    networks <https://arxiv.org/abs/1705.08790>`_.
+
+    Note:
+        When ``per_image`` is False the constructor asserts
+        ``reduction == 'none'``, so the default argument combination
+        (``per_image=False``, ``reduction='mean'``) raises an
+        AssertionError; pass ``reduction='none'`` or ``per_image=True``.
+
+    Args:
+        loss_type (str, optional): Binary or multi-class loss.
+            Default: 'multi_class'. Options are "binary" and "multi_class".
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 loss_type='multi_class',
+                 classes='present',
+                 per_image=False,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0):
+        super(LovaszLoss, self).__init__()
+        assert loss_type in ('binary', 'multi_class'), "loss_type should be \
+ 'binary' or 'multi_class'."
+
+        # 'binary' selects the hinge variant, anything else the softmax one.
+        if loss_type == 'binary':
+            self.cls_criterion = lovasz_hinge
+        else:
+            self.cls_criterion = lovasz_softmax
+        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
+        # Per-batch mode yields a single value, so only 'none' is accepted.
+        if not per_image:
+            assert reduction == 'none', "reduction should be 'none' when \
+ per_image is False."
+
+        self.classes = classes
+        self.per_image = per_image
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = get_class_weight(class_weight)
+
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        """Forward function.
+
+        ``weight`` is accepted but not used by this loss.
+        ``reduction_override``, when truthy, replaces the reduction chosen
+        at construction time. Extra keyword arguments are forwarded to the
+        selected criterion.
+        """
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        # if multi-class loss, transform logits to probs
+        if self.cls_criterion == lovasz_softmax:
+            cls_score = F.softmax(cls_score, dim=1)
+
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score,
+            label,
+            self.classes,
+            self.per_image,
+            class_weight=class_weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss_cls
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/utils.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2afb477a153ba9dead71066fa66ee024482afd82
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/losses/utils.py
@@ -0,0 +1,121 @@
+import functools
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+import torch.nn.functional as F
+
+
+def get_class_weight(class_weight):
+ """Get class weight for loss function.
+
+ Args:
+ class_weight (list[float] | str | None): If class_weight is a str,
+ take it as a file name and read from it.
+
+ Returns:
+ The loaded file contents when ``class_weight`` is a str; otherwise
+ ``class_weight`` itself, unchanged.
+ """
+ if isinstance(class_weight, str):
+ # take it as a file path
+ if class_weight.endswith('.npy'):
+ # paths ending in '.npy' are read with numpy
+ class_weight = np.load(class_weight)
+ else:
+ # pkl, json or yaml
+ class_weight = mmcv.load(class_weight)
+
+ return class_weight
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ # F._Reduction is a private helper of torch.nn.functional; get_enum is
+ # expected to translate the reduction string into the integer codes
+ # listed below.
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+ # NOTE(review): any other enum value falls through and returns None
+ # implicitly.
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Defaults to None.
+ reduction (str): Same as built-in losses of PyTorch.
+ Defaults to 'mean'.
+ avg_factor (float): Average factor when computing the mean of losses.
+ Defaults to None.
+
+ Returns:
+ Tensor: Processed loss values.
+
+ Raises:
+ ValueError: If ``avg_factor`` is given and ``reduction`` is neither
+ 'mean' nor 'none'.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ # weight must have the same number of dims as loss; along dim 1 it
+ # must either match loss or have size 1.
+ assert weight.dim() == loss.dim()
+ if weight.dim() > 1:
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ loss = loss.sum() / avg_factor
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+ raise ValueError('avg_factor can not be used with reduction="sum"')
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ avg_factor=None, **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, avg_factor=2)
+ tensor(1.5000)
+ """
+
+ # functools.wraps keeps the wrapped function's name and docstring on the
+ # returned wrapper.
+ @functools.wraps(loss_func)
+ def wrapper(pred,
+ target,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ # weighting and reduction are delegated to weight_reduce_loss, so
+ # weight/reduction/avg_factor are never passed to loss_func itself
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+ return wrapper
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9d3d5b3fe80247642d962edd6fb787537d01d6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/__init__.py
@@ -0,0 +1,4 @@
+from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
+
+__all__ = ['FPN', 'MultiLevelNeck']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/fpn.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba47bbe1a0225587315627ac288e5ddf6497a244
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/fpn.py
@@ -0,0 +1,212 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule, xavier_init
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(nn.Module):
+ """Feature Pyramid Network.
+
+ This is an implementation of - Feature Pyramid Networks for Object
+ Detection (https://arxiv.org/abs/1612.03144)
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+ on the original feature from the backbone. If True,
+ it is equivalent to `add_extra_convs='on_input'`. If False, it is
+ equivalent to set `add_extra_convs='on_output'`. Default: False.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(mode='nearest')`
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ # copy so the instance does not alias the (shared) default dict or the
+ # caller's dict
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1:
+ # use every input level from start_level on; extra output levels
+ # beyond the backbone levels are allowed
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ # a bare True is normalised to one of the string modes
+ if extra_convs_on_inputs:
+ # For compatibility with previous release
+ # TODO: deprecate `extra_convs_on_inputs`
+ self.add_extra_convs = 'on_input'
+ else:
+ self.add_extra_convs = 'on_output'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ # one 1x1 lateral conv and one 3x3 output conv per used backbone level
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ # NOTE: the local name `in_channels` is rebound here to a single
+ # int; the constructor argument was already saved on
+ # self.in_channels above.
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ # stride-2 3x3 conv; appended after the per-level convs in
+ # fpn_convs
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ # calls xavier_init(..., distribution='uniform') on every nn.Conv2d
+ # submodule of this neck
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ # `inputs` must hold one feature map per entry of in_channels; the
+ # result is a tuple of feature maps.
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ # walk from the last lateral to the first, adding each upsampled level
+ # into the one before it (in place)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ laterals[i - 1] += F.interpolate(laterals[i],
+ **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ # pick the tensor that feeds the first extra conv
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ # every further extra conv consumes the previous output
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/multilevel_neck.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b86c073cd1a72354d2426846125e80f7ab20dbc
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,70 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class MultiLevelNeck(nn.Module):
+ """MultiLevelNeck.
+
+ A neck structure connect vit backbone and decoder_heads.
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ scales (List[int]): Scale factors for each input feature map.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ scales=[0.5, 1, 2, 4],
+ norm_cfg=None,
+ act_cfg=None):
+ super(MultiLevelNeck, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scales = scales
+ self.num_outs = len(scales)
+ self.lateral_convs = nn.ModuleList()
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.lateral_convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ kernel_size=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+ for _ in range(self.num_outs):
+ self.convs.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+ print(inputs[0].shape)
+ inputs = [
+ lateral_conv(inputs[i])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ # for len(inputs) not equal to self.num_outs
+ if len(inputs) == 1:
+ inputs = [inputs[0] for _ in range(self.num_outs)]
+ outs = []
+ for i in range(self.num_outs):
+ x_resize = F.interpolate(
+ inputs[i], scale_factor=self.scales[i], mode='bilinear')
+ outs.append(self.convs[i](x_resize))
+ return tuple(outs)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca2f09405330743c476e190896bee39c45498ea
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseSegmentor
+from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .encoder_decoder import EncoderDecoder
+
+__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/base.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12d8beb8ea40bfa234197eddb4d3ef40dbfeb6f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/base.py
@@ -0,0 +1,273 @@
+import logging
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from annotator.mmpkg.mmcv.runner import auto_fp16
+
+
+class BaseSegmentor(nn.Module):
+ """Base class for segmentors."""
+
+ # NOTE(review): this is the Python 2 way of declaring a metaclass. Under
+ # Python 3 a class attribute named __metaclass__ has no effect, so the
+ # @abstractmethod markers below are not enforced at instantiation time.
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ super(BaseSegmentor, self).__init__()
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the segmentor has neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ @property
+ def with_auxiliary_head(self):
+ """bool: whether the segmentor has auxiliary head"""
+ return hasattr(self,
+ 'auxiliary_head') and self.auxiliary_head is not None
+
+ @property
+ def with_decode_head(self):
+ """bool: whether the segmentor has decode head"""
+ return hasattr(self, 'decode_head') and self.decode_head is not None
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Placeholder for extract features from images."""
+ pass
+
+ @abstractmethod
+ def encode_decode(self, img, img_metas):
+ """Placeholder for encode images with backbone and decode into a
+ semantic segmentation map of the same size as input."""
+ pass
+
+ @abstractmethod
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """Placeholder for Forward function for training."""
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, img_meta, **kwargs):
+ """Placeholder for single image test."""
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Placeholder for augmentation test."""
+ pass
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in segmentor.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ # The base implementation only logs the path (on the root logger); it
+ # does not load anything itself.
+ if pretrained is not None:
+ logger = logging.getLogger()
+ logger.info(f'load model from: {pretrained}')
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got '
+ f'{type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) != '
+ f'num of image meta ({len(img_metas)})')
+ # all images in the same aug batch all of the same ori_shape and pad
+ # shape
+ for img_meta in img_metas:
+ ori_shapes = [_['ori_shape'] for _ in img_meta]
+ assert all(shape == ori_shapes[0] for shape in ori_shapes)
+ img_shapes = [_['img_shape'] for _ in img_meta]
+ assert all(shape == img_shapes[0] for shape in img_shapes)
+ pad_shapes = [_['pad_shape'] for _ in img_meta]
+ assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+ # a single augmentation goes through simple_test, several through
+ # aug_test
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+ and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def train_step(self, data_batch, optimizer, **kwargs):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data_batch (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+ ``num_samples``.
+ ``loss`` is a tensor for back propagation, which can be a
+ weighted sum of multiple losses.
+ ``log_vars`` contains all the variables to be sent to the
+ logger.
+ ``num_samples`` indicates the batch size (when the model is
+ DDP, it means the batch size on each GPU), which is used for
+ averaging the logs.
+ """
+ # data_batch is unpacked into keyword arguments of forward(); neither
+ # optimizer nor **kwargs is used here
+ losses = self(**data_batch)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss,
+ log_vars=log_vars,
+ num_samples=len(data_batch['img_metas']))
+
+ return outputs
+
+ def val_step(self, data_batch, **kwargs):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ output = self(**data_batch, **kwargs)
+ return output
+
+ @staticmethod
+ def _parse_losses(losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+ losses (dict): Raw output of the network, which usually contain
+ losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+ which may be a weighted sum of all losses, log_vars contains
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ # collapse every entry to a scalar tensor: mean of a tensor, or the
+ # sum of the means for a list of tensors
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ # only entries whose key contains the substring 'loss' contribute to
+ # the total; other entries are logged but not summed
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ # work on a detached copy so the returned `loss` tensor is not
+ # modified by the in-place division
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ # logged values are plain Python numbers
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def show_result(self,
+ img,
+ result,
+ palette=None,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None,
+ opacity=0.5):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (Tensor): The semantic segmentation results to draw over
+ `img`.
+ palette (list[list[int]] | np.ndarray | None): The palette of
+ segmentation map. If None is given, random palette will be
+ generated. Default: None
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+ opacity(float): Opacity of painted segmentation map.
+ Default 0.5.
+ Must be in (0, 1] range.
+ Returns:
+ img (Tensor): Only if not `show` or `out_file`
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ # only the first entry of `result` is drawn
+ seg = result[0]
+ # NOTE(review): self.PALETTE and self.CLASSES are not defined anywhere
+ # in this class; they have to be set on the instance/subclass before
+ # this method is called.
+ if palette is None:
+ if self.PALETTE is None:
+ palette = np.random.randint(
+ 0, 255, size=(len(self.CLASSES), 3))
+ else:
+ palette = self.PALETTE
+ palette = np.array(palette)
+ # the palette must be a (num_classes, 3) array
+ assert palette.shape[0] == len(self.CLASSES)
+ assert palette.shape[1] == 3
+ assert len(palette.shape) == 2
+ assert 0 < opacity <= 1.0
+ color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+ # paint every pixel whose label equals the palette index
+ for label, color in enumerate(palette):
+ color_seg[seg == label, :] = color
+ # convert to BGR
+ color_seg = color_seg[..., ::-1]
+
+ # alpha-blend the colour map over the image
+ img = img * (1 - opacity) + color_seg * opacity
+ img = img.astype(np.uint8)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ # the blended image is returned only when it was neither shown nor
+ # written; otherwise the method returns None
+ if not (show or out_file):
+ warnings.warn('show==False and out_file is not specified, only '
+ 'result image will be returned')
+ return img
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/cascade_encoder_decoder.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/cascade_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..74547f0fb01da9fe32c1d142768eb788b7e8673c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/cascade_encoder_decoder.py
@@ -0,0 +1,98 @@
+from torch import nn
+
+from annotator.mmpkg.mmseg.core import add_prefix
+from annotator.mmpkg.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .encoder_decoder import EncoderDecoder
+
+
+@SEGMENTORS.register_module()
+class CascadeEncoderDecoder(EncoderDecoder):
+ """Cascade Encoder Decoder segmentors.
+
+ CascadeEncoderDecoder almost the same as EncoderDecoder, while decoders of
+ CascadeEncoderDecoder are cascaded. The output of previous decoder_head
+ will be the input of next decoder_head.
+ """
+
+ def __init__(self,
+ num_stages,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ # num_stages is stored before calling the parent constructor because
+ # _init_decode_head (below) reads self.num_stages; presumably the
+ # parent constructor is what invokes it.
+ self.num_stages = num_stages
+ super(CascadeEncoderDecoder, self).__init__(
+ backbone=backbone,
+ decode_head=decode_head,
+ neck=neck,
+ auxiliary_head=auxiliary_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ # decode_head must be a list with exactly one config per stage
+ assert isinstance(decode_head, list)
+ assert len(decode_head) == self.num_stages
+ self.decode_head = nn.ModuleList()
+ for i in range(self.num_stages):
+ self.decode_head.append(builder.build_head(decode_head[i]))
+ # align_corners and num_classes are taken from the last stage
+ self.align_corners = self.decode_head[-1].align_corners
+ self.num_classes = self.decode_head[-1].num_classes
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone and heads.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ # only the backbone receives `pretrained`; heads are initialised
+ # without arguments
+ self.backbone.init_weights(pretrained=pretrained)
+ for i in range(self.num_stages):
+ self.decode_head[i].init_weights()
+ if self.with_auxiliary_head:
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for aux_head in self.auxiliary_head:
+ aux_head.init_weights()
+ else:
+ self.auxiliary_head.init_weights()
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ # the first head sees only the features; every later head also
+ # receives the previous head's output
+ out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)
+ for i in range(1, self.num_stages):
+ out = self.decode_head[i].forward_test(x, out, img_metas,
+ self.test_cfg)
+ # resize the final output to the spatial size of the input image
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+
+ loss_decode = self.decode_head[0].forward_train(
+ x, img_metas, gt_semantic_seg, self.train_cfg)
+
+ # loss keys of stage i are prefixed with 'decode_<i>'
+ losses.update(add_prefix(loss_decode, 'decode_0'))
+
+ for i in range(1, self.num_stages):
+ # forward test again, maybe unnecessary for most methods.
+ # NOTE(review): the previous head is called with (x, img_metas,
+ # test_cfg) for every i, i.e. without a previous output, unlike
+ # the chained call in encode_decode - confirm this is intended
+ # for cascades with more than two stages.
+ prev_outputs = self.decode_head[i - 1].forward_test(
+ x, img_metas, self.test_cfg)
+ loss_decode = self.decode_head[i].forward_train(
+ x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_decode, f'decode_{i}'))
+
+ return losses
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/encoder_decoder.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..30c25f35a15e65e45f9221a3f19ace8579f73301
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/segmentors/encoder_decoder.py
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.mmpkg.mmseg.core import add_prefix
+from annotator.mmpkg.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .base import BaseSegmentor
+
+
+@SEGMENTORS.register_module()
+class EncoderDecoder(BaseSegmentor):
+ """Encoder Decoder segmentors.
+
+ EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+ Note that auxiliary_head is only used for deep supervision during training,
+ which could be dumped during inference.
+ """
+
+ def __init__(self,
+ backbone,
+ decode_head,
+ neck=None,
+ auxiliary_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(EncoderDecoder, self).__init__()
+ self.backbone = builder.build_backbone(backbone)
+ if neck is not None:
+ self.neck = builder.build_neck(neck)
+ self._init_decode_head(decode_head)
+ self._init_auxiliary_head(auxiliary_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.init_weights(pretrained=pretrained)
+
+ assert self.with_decode_head
+
+ def _init_decode_head(self, decode_head):
+ """Initialize ``decode_head``"""
+ self.decode_head = builder.build_head(decode_head)
+ self.align_corners = self.decode_head.align_corners
+ self.num_classes = self.decode_head.num_classes
+
+ def _init_auxiliary_head(self, auxiliary_head):
+ """Initialize ``auxiliary_head``"""
+ if auxiliary_head is not None:
+ if isinstance(auxiliary_head, list):
+ self.auxiliary_head = nn.ModuleList()
+ for head_cfg in auxiliary_head:
+ self.auxiliary_head.append(builder.build_head(head_cfg))
+ else:
+ self.auxiliary_head = builder.build_head(auxiliary_head)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone and heads.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ super(EncoderDecoder, self).init_weights(pretrained)
+ self.backbone.init_weights(pretrained=pretrained)
+ self.decode_head.init_weights()
+ if self.with_auxiliary_head:
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for aux_head in self.auxiliary_head:
+ aux_head.init_weights()
+ else:
+ self.auxiliary_head.init_weights()
+
+ def extract_feat(self, img):
+ """Extract features from images."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def encode_decode(self, img, img_metas):
+ """Encode images with backbone and decode into a semantic segmentation
+ map of the same size as input."""
+ x = self.extract_feat(img)
+ out = self._decode_head_forward_test(x, img_metas)
+ out = resize(
+ input=out,
+ size=img.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ return out
+
+ def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for decode head in
+ training."""
+ losses = dict()
+ loss_decode = self.decode_head.forward_train(x, img_metas,
+ gt_semantic_seg,
+ self.train_cfg)
+
+ losses.update(add_prefix(loss_decode, 'decode'))
+ return losses
+
+ def _decode_head_forward_test(self, x, img_metas):
+ """Run forward function and calculate loss for decode head in
+ inference."""
+ seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
+ return seg_logits
+
+ def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
+ """Run forward function and calculate loss for auxiliary head in
+ training."""
+ losses = dict()
+ if isinstance(self.auxiliary_head, nn.ModuleList):
+ for idx, aux_head in enumerate(self.auxiliary_head):
+ loss_aux = aux_head.forward_train(x, img_metas,
+ gt_semantic_seg,
+ self.train_cfg)
+ losses.update(add_prefix(loss_aux, f'aux_{idx}'))
+ else:
+ loss_aux = self.auxiliary_head.forward_train(
+ x, img_metas, gt_semantic_seg, self.train_cfg)
+ losses.update(add_prefix(loss_aux, 'aux'))
+
+ return losses
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ seg_logit = self.encode_decode(img, None)
+
+ return seg_logit
+
+ def forward_train(self, img, img_metas, gt_semantic_seg):
+ """Forward function for training.
+
+ Args:
+ img (Tensor): Input images.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
+ gt_semantic_seg (Tensor): Semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ loss_decode = self._decode_head_forward_train(x, img_metas,
+ gt_semantic_seg)
+ losses.update(loss_decode)
+
+ if self.with_auxiliary_head:
+ loss_aux = self._auxiliary_head_forward_train(
+ x, img_metas, gt_semantic_seg)
+ losses.update(loss_aux)
+
+ return losses
+
+ # TODO refactor
+    def slide_inference(self, img, img_meta, rescale):
+        """Inference by sliding-window with overlap.
+
+        Crops of ``test_cfg.crop_size`` are taken every ``test_cfg.stride``
+        pixels, decoded with ``encode_decode``, summed into a full-size
+        buffer and divided by the number of crops covering each pixel.
+
+        If h_crop > h_img or w_crop > w_img, the small patch will be used to
+        decode without padding.
+
+        Args:
+            img (Tensor): Input batch, unpacked as (batch, channels, h, w).
+            img_meta (list[dict]): Image info, passed to ``encode_decode``;
+                ``img_meta[0]['ori_shape']`` is read when ``rescale`` is set.
+            rescale (bool): Whether to resize the result to the first two
+                entries of ``ori_shape``.
+
+        Returns:
+            Tensor: The averaged predictions, allocated with
+            ``self.num_classes`` channels.
+        """
+
+        h_stride, w_stride = self.test_cfg.stride
+        h_crop, w_crop = self.test_cfg.crop_size
+        batch_size, _, h_img, w_img = img.size()
+        num_classes = self.num_classes
+        # Number of window positions per axis; never less than one, even
+        # when the crop is larger than the image.
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        # Running sum of predictions and per-pixel count of covering windows.
+        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                # Clip the window end to the image, then move the start back
+                # so the window keeps the full crop size whenever it fits.
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = img[:, :, y1:y2, x1:x2]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                # Pad the crop result (left, right, top, bottom) back to the
+                # full image size before accumulating it.
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+
+                count_mat[:, :, y1:y2, x1:x2] += 1
+        # Every pixel must be covered by at least one window, otherwise the
+        # division below would divide by zero.
+        assert (count_mat == 0).sum() == 0
+        if torch.onnx.is_in_onnx_export():
+            # cast count_mat to constant while exporting to ONNX
+            count_mat = torch.from_numpy(
+                count_mat.cpu().detach().numpy()).to(device=img.device)
+        preds = preds / count_mat
+        if rescale:
+            preds = resize(
+                preds,
+                size=img_meta[0]['ori_shape'][:2],
+                mode='bilinear',
+                align_corners=self.align_corners,
+                warning=False)
+        return preds
+
+    def whole_inference(self, img, img_meta, rescale):
+        """Inference on the full image in a single pass.
+
+        Args:
+            img (Tensor): Input batch passed to ``encode_decode``.
+            img_meta (list[dict]): Image info; ``img_meta[0]['ori_shape']``
+                gives the target size when rescaling outside ONNX export.
+            rescale (bool): Whether to resize the result.
+
+        Returns:
+            Tensor: The (optionally resized) output of ``encode_decode``.
+        """
+        logits = self.encode_decode(img, img_meta)
+        if not rescale:
+            return logits
+
+        # support dynamic shape for onnx
+        if torch.onnx.is_in_onnx_export():
+            target_size = img.shape[2:]
+        else:
+            target_size = img_meta[0]['ori_shape'][:2]
+        return resize(
+            logits,
+            size=target_size,
+            mode='bilinear',
+            align_corners=self.align_corners,
+            warning=False)
+
+    def inference(self, img, img_meta, rescale):
+        """Inference with slide/whole style.
+
+        The logits of the selected inference mode are turned into class
+        probabilities with a softmax over dim 1. When the first meta entry
+        reports a flip, the probabilities are flipped back along the
+        matching axis.
+
+        Args:
+            img (Tensor): The input image of shape (N, 3, H, W).
+            img_meta (list[dict]): Image info dicts, each with: 'img_shape',
+                'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                All entries must share the same 'ori_shape'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            rescale (bool): Whether rescale back to original shape.
+
+        Returns:
+            Tensor: The output segmentation map.
+        """
+        mode = self.test_cfg.mode
+        assert mode in ['slide', 'whole']
+        first_meta = img_meta[0]
+        reference_shape = first_meta['ori_shape']
+        assert all(meta['ori_shape'] == reference_shape for meta in img_meta)
+
+        if mode == 'slide':
+            run_inference = self.slide_inference
+        else:
+            run_inference = self.whole_inference
+        probs = F.softmax(run_inference(img, img_meta, rescale), dim=1)
+
+        if not first_meta['flip']:
+            return probs
+
+        direction = first_meta['flip_direction']
+        assert direction in ['horizontal', 'vertical']
+        # Width is dim 3, height is dim 2.
+        flip_dims = {'horizontal': (3, ), 'vertical': (2, )}.get(direction)
+        if flip_dims is not None:
+            probs = probs.flip(dims=flip_dims)
+        return probs
+
+    def simple_test(self, img, img_meta, rescale=True):
+        """Predict per-pixel class indices for one (non-augmented) input.
+
+        Returns:
+            During ONNX export, the argmax tensor with an extra leading
+            dimension; otherwise a list with one numpy array per batch
+            element.
+        """
+        probs = self.inference(img, img_meta, rescale)
+        labels = probs.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            # our inference backend only support 4D output
+            return labels.unsqueeze(0)
+        # unravel batch dim
+        return list(labels.cpu().numpy())
+
+    def aug_test(self, imgs, img_metas, rescale=True):
+        """Predict class indices from the mean over several augmentations.
+
+        Only rescale=True is supported.
+
+        Returns:
+            list: One numpy array of class indices per batch element.
+        """
+        # aug_test rescale all imgs back to ori_shape for now
+        assert rescale
+        # to save memory, the results are accumulated in place into the
+        # tensor returned for the first augmentation
+        accumulated = self.inference(imgs[0], img_metas[0], rescale)
+        for aug_idx in range(1, len(imgs)):
+            accumulated += self.inference(imgs[aug_idx], img_metas[aug_idx],
+                                          rescale)
+        accumulated /= len(imgs)
+        labels = accumulated.argmax(dim=1).cpu().numpy()
+        # unravel batch dim
+        return list(labels)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/__init__.py
@@ -0,0 +1,13 @@
+from .drop import DropPath
+from .inverted_residual import InvertedResidual, InvertedResidualV3
+from .make_divisible import make_divisible
+from .res_layer import ResLayer
+from .se_layer import SELayer
+from .self_attention_block import SelfAttentionBlock
+from .up_conv_block import UpConvBlock
+from .weight_init import trunc_normal_
+
+__all__ = [
+ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
+ 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+]
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/drop.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..4520b0ff407d2a95a864086bdbca0065f222aa63
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/drop.py
@@ -0,0 +1,31 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import torch
+from torch import nn
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of
+    residual blocks).
+
+    In training mode each sample of the batch is either zeroed or kept and
+    rescaled by ``1 / keep_prob``. In eval mode, or with ``drop_prob == 0``,
+    the input is returned unchanged.
+
+    Args:
+        drop_prob (float): Drop rate for paths of model. Dropout rate has
+            to be between 0 and 1. Default: 0.
+    """
+
+    def __init__(self, drop_prob=0.):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.keep_prob = 1 - drop_prob
+
+    def forward(self, x):
+        """Apply stochastic depth to ``x`` (first dim is the sample dim)."""
+        if self.drop_prob == 0. or not self.training:
+            return x
+        shape = (x.shape[0], ) + (1, ) * (
+            x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+        random_tensor = self.keep_prob + torch.rand(
+            shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        # With drop_prob == 1 the keep probability is 0: every sample is
+        # dropped, and rescaling by 1 / keep_prob would turn the output into
+        # NaN (inf * 0). Skip the rescale in that case.
+        if self.keep_prob > 0.:
+            x = x.div(self.keep_prob)
+        return x * random_tensor
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/inverted_residual.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..2df5ebd7c94c0a66b0d05ef9e200ddbeabfa79f6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/inverted_residual.py
@@ -0,0 +1,208 @@
+from annotator.mmpkg.mmcv.cnn import ConvModule
+from torch import nn
+from torch.utils import checkpoint as cp
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(nn.Module):
+    """Inverted residual block used by MobileNetV2.
+
+    The block chains an optional 1x1 expansion conv, a 3x3 depthwise conv
+    and a 1x1 projection conv without activation. The input is added to the
+    result when ``stride == 1`` and ``in_channels == out_channels``.
+
+    Args:
+        in_channels (int): Channels of the block input.
+        out_channels (int): Channels of the block output.
+        stride (int): Stride of the 3x3 depthwise conv; must be 1 or 2.
+        expand_ratio (int): Multiplier giving the hidden channel count
+            ``int(round(in_channels * expand_ratio))``. The expansion conv
+            is omitted when it equals 1.
+        dilation (int): Dilation (also used as padding) of the depthwise
+            conv. Default: 1
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for the activation of the expansion and
+            depthwise convs. Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+            Default: False.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 expand_ratio,
+                 dilation=1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2], f'stride must in [1, 2]. ' \
+            f'But received {stride}.'
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and in_channels == out_channels
+        hidden_dim = int(round(in_channels * expand_ratio))
+        # Settings shared by every conv of the block.
+        shared_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
+
+        stages = []
+        if expand_ratio != 1:
+            # 1x1 expansion to the hidden width.
+            stages.append(
+                ConvModule(
+                    in_channels=in_channels,
+                    out_channels=hidden_dim,
+                    kernel_size=1,
+                    act_cfg=act_cfg,
+                    **shared_cfg))
+        # 3x3 depthwise conv (one group per hidden channel).
+        stages.append(
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                stride=stride,
+                padding=dilation,
+                dilation=dilation,
+                groups=hidden_dim,
+                act_cfg=act_cfg,
+                **shared_cfg))
+        # 1x1 projection without activation.
+        stages.append(
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                act_cfg=None,
+                **shared_cfg))
+        self.conv = nn.Sequential(*stages)
+
+    def forward(self, x):
+        """Run the block, through a checkpoint when enabled."""
+
+        def _residual_branch(inp):
+            branch = self.conv(inp)
+            return inp + branch if self.use_res_connect else branch
+
+        if self.with_cp and x.requires_grad:
+            return cp.checkpoint(_residual_branch, x)
+        return _residual_branch(x)
+
+
+class InvertedResidualV3(nn.Module):
+    """Inverted Residual Block for MobileNetV3.
+
+    The block chains an optional 1x1 expansion conv, a depthwise conv, an
+    optional SE layer and a 1x1 linear conv without activation. The input is
+    added to the result when ``stride == 1`` and
+    ``in_channels == out_channels``.
+
+    Args:
+        in_channels (int): The input channels of this Module.
+        out_channels (int): The output channels of this Module.
+        mid_channels (int): The input channels of the depthwise convolution.
+        kernel_size (int): The kernel size of the depthwise convolution.
+            Default: 3.
+        stride (int): The stride of the depthwise convolution; must be 1 or
+            2. Default: 1.
+        se_cfg (dict): Config dict for se layer. Default: None, which means
+            no se layer.
+        with_expand_conv (bool): Use expand conv or not. If set False,
+            mid_channels must be the same with in_channels. Default: True.
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d. The depthwise conv uses
+            ``dict(type='Conv2dAdaptivePadding')`` instead when stride is 2.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+            Default: False.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 mid_channels,
+                 kernel_size=3,
+                 stride=1,
+                 se_cfg=None,
+                 with_expand_conv=True,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 with_cp=False):
+        super(InvertedResidualV3, self).__init__()
+        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+        assert stride in [1, 2]
+        self.with_cp = with_cp
+        self.with_se = se_cfg is not None
+        self.with_expand_conv = with_expand_conv
+
+        if self.with_se:
+            assert isinstance(se_cfg, dict)
+        if not self.with_expand_conv:
+            assert mid_channels == in_channels
+
+        # Settings shared by the two pointwise (1x1) convs.
+        pointwise_cfg = dict(
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg)
+
+        if self.with_expand_conv:
+            self.expand_conv = ConvModule(
+                in_channels=in_channels,
+                out_channels=mid_channels,
+                act_cfg=act_cfg,
+                **pointwise_cfg)
+
+        if stride == 2:
+            depthwise_conv_cfg = dict(type='Conv2dAdaptivePadding')
+        else:
+            depthwise_conv_cfg = conv_cfg
+        self.depthwise_conv = ConvModule(
+            in_channels=mid_channels,
+            out_channels=mid_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=kernel_size // 2,
+            groups=mid_channels,
+            conv_cfg=depthwise_conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        if self.with_se:
+            self.se = SELayer(**se_cfg)
+
+        self.linear_conv = ConvModule(
+            in_channels=mid_channels,
+            out_channels=out_channels,
+            act_cfg=None,
+            **pointwise_cfg)
+
+    def forward(self, x):
+        """Run the block, through a checkpoint when enabled."""
+
+        def _residual_branch(inp):
+            feat = self.expand_conv(inp) if self.with_expand_conv else inp
+            feat = self.depthwise_conv(feat)
+            if self.with_se:
+                feat = self.se(feat)
+            feat = self.linear_conv(feat)
+            return inp + feat if self.with_res_shortcut else feat
+
+        if self.with_cp and x.requires_grad:
+            return cp.checkpoint(_residual_branch, x)
+        return _residual_branch(x)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/make_divisible.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ad756052529f52fe83bb95dd1f0ecfc9a13078
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/make_divisible.py
@@ -0,0 +1,27 @@
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+    """Round a channel number to a multiple of ``divisor``.
+
+    The value is rounded to the nearest multiple of ``divisor`` and raised
+    to ``min_value`` if it falls below it. If the result is smaller than
+    ``min_ratio * value`` one more ``divisor`` is added. The scheme follows
+    the original tf repo:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa
+
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int): The minimum value of the output channel.
+            Default: None, means that the minimum value equal to the divisor.
+        min_ratio (float): The minimum ratio of the rounded channel number to
+            the original channel number. Default: 0.9.
+
+    Returns:
+        int: The modified output channel number.
+    """
+    lower_limit = divisor if min_value is None else min_value
+    nearest_multiple = int(value + divisor / 2) // divisor * divisor
+    rounded = max(lower_limit, nearest_multiple)
+    # Make sure that round down does not go down by more than (1-min_ratio).
+    if rounded < min_ratio * value:
+        rounded += divisor
+    return rounded
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/res_layer.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d41075a57356b4fd802bc4ff199e55e63678b589
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/res_layer.py
@@ -0,0 +1,94 @@
+from annotator.mmpkg.mmcv.cnn import build_conv_layer, build_norm_layer
+from torch import nn as nn
+
+
+class ResLayer(nn.Sequential):
+    """Sequence of residual blocks forming one ResNet-style stage.
+
+    The first block carries the stride and, when the stride is not 1 or the
+    channel count changes, a downsample branch; the remaining blocks use
+    stride 1.
+
+    Args:
+        block (nn.Module): block used to build ResLayer.
+        inplanes (int): inplanes of block.
+        planes (int): planes of block.
+        num_blocks (int): number of blocks.
+        stride (int): stride of the first block. Default: 1
+        dilation (int): dilation passed to the blocks when ``multi_grid`` is
+            None. Default: 1
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck. Default: False
+        conv_cfg (dict): dictionary to construct and config conv layer.
+            Default: None
+        norm_cfg (dict): dictionary to construct and config norm layer.
+            Default: dict(type='BN')
+        multi_grid (Sequence[int] | None): Per-block dilation rates, indexed
+            by block position. Default: None
+        contract_dilation (bool): Whether to halve the dilation of the first
+            block when ``dilation > 1``. Default: False
+        **kwargs: Extra keyword arguments forwarded to every block.
+    """
+
+    def __init__(self,
+                 block,
+                 inplanes,
+                 planes,
+                 num_blocks,
+                 stride=1,
+                 dilation=1,
+                 avg_down=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 multi_grid=None,
+                 contract_dilation=False,
+                 **kwargs):
+        self.block = block
+        out_planes = planes * block.expansion
+
+        downsample = None
+        if stride != 1 or inplanes != out_planes:
+            shortcut = []
+            conv_stride = stride
+            if avg_down:
+                # Pool does the spatial reduction, so the conv keeps stride 1.
+                conv_stride = 1
+                shortcut.append(
+                    nn.AvgPool2d(
+                        kernel_size=stride,
+                        stride=stride,
+                        ceil_mode=True,
+                        count_include_pad=False))
+            shortcut.append(
+                build_conv_layer(
+                    conv_cfg,
+                    inplanes,
+                    out_planes,
+                    kernel_size=1,
+                    stride=conv_stride,
+                    bias=False))
+            shortcut.append(build_norm_layer(norm_cfg, out_planes)[1])
+            downsample = nn.Sequential(*shortcut)
+
+        if multi_grid is not None:
+            first_dilation = multi_grid[0]
+        elif dilation > 1 and contract_dilation:
+            first_dilation = dilation // 2
+        else:
+            first_dilation = dilation
+
+        blocks = [
+            block(
+                inplanes=inplanes,
+                planes=planes,
+                stride=stride,
+                dilation=first_dilation,
+                downsample=downsample,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                **kwargs)
+        ]
+        inplanes = out_planes
+        for i in range(1, num_blocks):
+            blocks.append(
+                block(
+                    inplanes=inplanes,
+                    planes=planes,
+                    stride=1,
+                    dilation=dilation if multi_grid is None else multi_grid[i],
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    **kwargs))
+        super(ResLayer, self).__init__(*blocks)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/se_layer.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..42ab005e1fe2211e9ecb651d31de128cf95cfec7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/se_layer.py
@@ -0,0 +1,57 @@
+import annotator.mmpkg.mmcv as mmcv
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule
+
+from .make_divisible import make_divisible
+
+
+class SELayer(nn.Module):
+    """Squeeze-and-Excitation Module.
+
+    The input is globally average-pooled, passed through two 1x1 convs and
+    the result is multiplied with the input.
+
+    Args:
+        channels (int): The input (and output) channels of the SE layer.
+        ratio (int): Squeeze ratio in SELayer, the intermediate channel will
+            be ``make_divisible(channels // ratio, 8)``. Default: 16.
+        conv_cfg (None or dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+            If act_cfg is a dict, two activation layers will be configured
+            by this dict. If act_cfg is a sequence of dicts, the first
+            activation layer will be configured by the first dict and the
+            second activation layer will be configured by the second dict.
+            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+            divisor=6.0)).
+    """
+
+    def __init__(self,
+                 channels,
+                 ratio=16,
+                 conv_cfg=None,
+                 act_cfg=(dict(type='ReLU'),
+                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):
+        super(SELayer, self).__init__()
+        if isinstance(act_cfg, dict):
+            # One config given: use it for both activations.
+            act_cfg = (act_cfg, act_cfg)
+        assert len(act_cfg) == 2
+        assert mmcv.is_tuple_of(act_cfg, dict)
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+        # Width of the squeezed representation, shared by both convs.
+        squeezed_channels = make_divisible(channels // ratio, 8)
+        self.conv1 = ConvModule(
+            in_channels=channels,
+            out_channels=squeezed_channels,
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[0])
+        self.conv2 = ConvModule(
+            in_channels=squeezed_channels,
+            out_channels=channels,
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[1])
+
+    def forward(self, x):
+        """Scale ``x`` by the gate computed from its pooled statistics."""
+        gate = self.conv2(self.conv1(self.global_avgpool(x)))
+        return x * gate
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/self_attention_block.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/self_attention_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..a342e2b29ad53916c98d0342bde8f0f6cb10197a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/self_attention_block.py
@@ -0,0 +1,159 @@
+import torch
+from annotator.mmpkg.mmcv.cnn import ConvModule, constant_init
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class SelfAttentionBlock(nn.Module):
+    """General self-attention block/non-local block.
+
+    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
+    query and value.
+
+    Args:
+        key_in_channels (int): Input channels of key feature.
+        query_in_channels (int): Input channels of query feature.
+        channels (int): Output channels of key/query transform.
+        out_channels (int): Output channels.
+        share_key_query (bool): Whether share projection weight between key
+            and query projection. Requires
+            ``key_in_channels == query_in_channels``.
+        query_downsample (nn.Module): Query downsample module.
+        key_downsample (nn.Module): Key downsample module. It is applied to
+            both the key and the value.
+        key_query_num_convs (int): Number of convs for key/query projection.
+        value_out_num_convs (int): Number of convs for value projection and
+            for the out projection.
+        key_query_norm (bool): Whether the key/query projections are built
+            from ``ConvModule`` (using norm_cfg/act_cfg) instead of plain
+            ``nn.Conv2d``.
+        value_out_norm (bool): Whether the value/out projections are built
+            from ``ConvModule`` (using norm_cfg/act_cfg) instead of plain
+            ``nn.Conv2d``.
+        matmul_norm (bool): Whether normalize attention map with sqrt of
+            channels
+        with_out (bool): Whether use out projection.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict|None): Config of activation layers.
+    """
+
+    def __init__(self, key_in_channels, query_in_channels, channels,
+                 out_channels, share_key_query, query_downsample,
+                 key_downsample, key_query_num_convs, value_out_num_convs,
+                 key_query_norm, value_out_norm, matmul_norm, with_out,
+                 conv_cfg, norm_cfg, act_cfg):
+        super(SelfAttentionBlock, self).__init__()
+        if share_key_query:
+            assert key_in_channels == query_in_channels
+        self.key_in_channels = key_in_channels
+        self.query_in_channels = query_in_channels
+        self.out_channels = out_channels
+        self.channels = channels
+        self.share_key_query = share_key_query
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.key_project = self.build_project(
+            key_in_channels,
+            channels,
+            num_convs=key_query_num_convs,
+            use_conv_module=key_query_norm,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        if share_key_query:
+            # Same module object: key and query share their weights.
+            self.query_project = self.key_project
+        else:
+            self.query_project = self.build_project(
+                query_in_channels,
+                channels,
+                num_convs=key_query_num_convs,
+                use_conv_module=key_query_norm,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        # Without an out projection the value projection has to produce
+        # out_channels itself.
+        self.value_project = self.build_project(
+            key_in_channels,
+            channels if with_out else out_channels,
+            num_convs=value_out_num_convs,
+            use_conv_module=value_out_norm,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        if with_out:
+            self.out_project = self.build_project(
+                channels,
+                out_channels,
+                num_convs=value_out_num_convs,
+                use_conv_module=value_out_norm,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            self.out_project = None
+
+        self.query_downsample = query_downsample
+        self.key_downsample = key_downsample
+        self.matmul_norm = matmul_norm
+
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize weight of later layer."""
+        # constant_init(..., 0) is applied to the out projection unless it
+        # is a single ConvModule.
+        if self.out_project is not None:
+            if not isinstance(self.out_project, ConvModule):
+                constant_init(self.out_project, 0)
+
+    def build_project(self, in_channels, channels, num_convs, use_conv_module,
+                      conv_cfg, norm_cfg, act_cfg):
+        """Build projection layer for key/query/value/out.
+
+        Creates ``num_convs`` 1x1 convs (``ConvModule`` when
+        ``use_conv_module`` is set, plain ``nn.Conv2d`` otherwise). A single
+        conv is returned as is; several are wrapped in ``nn.Sequential``.
+        """
+        if use_conv_module:
+            convs = [
+                ConvModule(
+                    in_channels,
+                    channels,
+                    1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg)
+            ]
+            for _ in range(num_convs - 1):
+                convs.append(
+                    ConvModule(
+                        channels,
+                        channels,
+                        1,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg))
+        else:
+            convs = [nn.Conv2d(in_channels, channels, 1)]
+            for _ in range(num_convs - 1):
+                convs.append(nn.Conv2d(channels, channels, 1))
+        if len(convs) > 1:
+            convs = nn.Sequential(*convs)
+        else:
+            convs = convs[0]
+        return convs
+
+    def forward(self, query_feats, key_feats):
+        """Forward function.
+
+        Args:
+            query_feats (Tensor): Features the query is projected from.
+            key_feats (Tensor): Features the key and value are projected
+                from.
+
+        Returns:
+            Tensor: The attended context, reshaped to the batch size and
+            trailing dims of ``query_feats``.
+        """
+        batch_size = query_feats.size(0)
+        query = self.query_project(query_feats)
+        if self.query_downsample is not None:
+            query = self.query_downsample(query)
+        # Flatten everything after the first two dims, then swap the last
+        # two dims so the flattened positions come before the channels.
+        query = query.reshape(*query.shape[:2], -1)
+        query = query.permute(0, 2, 1).contiguous()
+
+        key = self.key_project(key_feats)
+        value = self.value_project(key_feats)
+        if self.key_downsample is not None:
+            key = self.key_downsample(key)
+            value = self.key_downsample(value)
+        # The key stays channels-first so that query @ key gives one
+        # similarity per (query position, key position) pair.
+        key = key.reshape(*key.shape[:2], -1)
+        value = value.reshape(*value.shape[:2], -1)
+        value = value.permute(0, 2, 1).contiguous()
+
+        sim_map = torch.matmul(query, key)
+        if self.matmul_norm:
+            sim_map = (self.channels**-.5) * sim_map
+        # Normalize over the key positions.
+        sim_map = F.softmax(sim_map, dim=-1)
+
+        context = torch.matmul(sim_map, value)
+        context = context.permute(0, 2, 1).contiguous()
+        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
+        if self.out_project is not None:
+            context = self.out_project(context)
+        return context
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/up_conv_block.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/up_conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..86328011a9704d17e9f9d0d54994719ead5caa56
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.mmpkg.mmcv.cnn import ConvModule, build_upsample_layer
+
+
+class UpConvBlock(nn.Module):
+    """Upsample convolution block in decoder for UNet.
+
+    This upsample convolution block consists of one upsample module
+    followed by one convolution block. The upsample module expands the
+    high-level low-resolution feature map and the convolution block fuses
+    the upsampled high-level low-resolution feature map and the low-level
+    high-resolution feature map from encoder.
+
+    Args:
+        conv_block (nn.Sequential): Sequential of convolutional layers.
+        in_channels (int): Number of input channels of the high-level
+            low-resolution feature map (the ``x`` argument of ``forward``,
+            which goes through the upsample module).
+        skip_channels (int): Number of input channels of the low-level
+            high-resolution feature map from encoder.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers in the conv_block.
+            Default: 2.
+        stride (int): Stride of convolutional layer in conv_block. Default: 1.
+        dilation (int): Dilation rate of convolutional layer in conv_block.
+            Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv'). If the size of
+            high-level feature map is the same as that of skip feature map
+            (low-level feature map from encoder), it does not need upsample the
+            high-level feature map and the upsample_cfg is None. In that
+            case a 1x1 ConvModule maps in_channels to skip_channels instead.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Must be None (not implemented). Default: None.
+        plugins (dict): plugins for convolutional layers. Must be None (not
+            implemented). Default: None.
+    """
+
+    def __init__(self,
+                 conv_block,
+                 in_channels,
+                 skip_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 dcn=None,
+                 plugins=None):
+        super(UpConvBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        # The conv block consumes the concatenation of the skip feature and
+        # the upsampled feature, both of which have skip_channels channels
+        # (the upsample module below outputs skip_channels).
+        self.conv_block = conv_block(
+            in_channels=2 * skip_channels,
+            out_channels=out_channels,
+            num_convs=num_convs,
+            stride=stride,
+            dilation=dilation,
+            with_cp=with_cp,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            dcn=None,
+            plugins=None)
+        if upsample_cfg is not None:
+            self.upsample = build_upsample_layer(
+                cfg=upsample_cfg,
+                in_channels=in_channels,
+                out_channels=skip_channels,
+                with_cp=with_cp,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            # No upsample config: only adapt the channel count with a 1x1
+            # conv.
+            self.upsample = ConvModule(
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+
+    def forward(self, skip, x):
+        """Forward function.
+
+        Args:
+            skip (Tensor): Skip feature, concatenated first along dim 1.
+            x (Tensor): Feature passed through the upsample module before
+                the concatenation.
+
+        Returns:
+            Tensor: Output of the conv block applied to the concatenation.
+        """
+
+        x = self.upsample(x)
+        out = torch.cat([skip, x], dim=1)
+        out = self.conv_block(out)
+
+        return out
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/weight_init.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..38141ba3d61f64ddfc0a31574b4648cbad96d7dd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/models/utils/weight_init.py
@@ -0,0 +1,62 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/weight_init.py."""
+
+import math
+import warnings
+
+import torch
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    """Fill ``tensor`` in place with truncated-normal samples.
+
+    Samples are drawn uniformly between the normal CDF values of the two
+    cutoffs, mapped through the inverse CDF, scaled to ``mean``/``std`` and
+    finally clamped to ``[a, b]``. No gradient is recorded.
+
+    Reference: https://people.sc.fsu.edu/~jburkardt/presentations
+    /truncated_normal.pdf
+
+    Returns:
+        The same ``tensor`` object, modified in place.
+    """
+    sqrt_two = math.sqrt(2.)
+
+    def _std_normal_cdf(value):
+        # Cumulative distribution function of the standard normal.
+        return (1. + math.erf(value / sqrt_two)) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # CDF values of the lower and upper cutoff.
+        cdf_low = _std_normal_cdf((a - mean) / std)
+        cdf_high = _std_normal_cdf((b - mean) / std)
+
+        # Uniform samples in [2l-1, 2u-1], turned into truncated standard
+        # normal samples by the inverse error function.
+        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1).erfinv_()
+
+        # Move to the requested mean and std.
+        tensor.mul_(std * sqrt_two).add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fill ``tensor`` in place from a truncated normal distribution.
+
+    The values are effectively drawn from the normal distribution
+    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
+    :math:`[a, b]` redrawn until they are within the bounds. The method
+    used for generating the random values works best when
+    :math:`a \leq \text{mean} \leq b`.
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
+        mean (float): the mean of the normal distribution
+        std (float): the standard deviation of the normal distribution
+        a (float): the minimum cutoff value
+        b (float): the maximum cutoff value
+
+    Returns:
+        The value returned by ``_no_grad_trunc_normal_`` for ``tensor``.
+    """
+    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
+    return filled
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/__init__.py
@@ -0,0 +1,4 @@
+from .encoding import Encoding
+from .wrappers import Upsample, resize
+
+__all__ = ['Upsample', 'resize', 'Encoding']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/encoding.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/encoding.py
@@ -0,0 +1,74 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Encoding(nn.Module):
+    """Encoding Layer: a learnable residual encoder.
+
+    Input is of shape (batch_size, channels, height, width).
+    Output is of shape (batch_size, num_codes, channels).
+
+    Args:
+        channels: dimension of the features or feature channels
+        num_codes: number of code words
+    """
+
+    def __init__(self, channels, num_codes):
+        super(Encoding, self).__init__()
+        self.channels = channels
+        self.num_codes = num_codes
+        # Bound of the uniform init of the codewords.
+        std = 1. / ((num_codes * channels)**0.5)
+        # codewords: [num_codes, channels]
+        initial_codewords = torch.empty(
+            num_codes, channels, dtype=torch.float).uniform_(-std, std)
+        self.codewords = nn.Parameter(initial_codewords, requires_grad=True)
+        # smoothing factors: [num_codes]
+        initial_scale = torch.empty(
+            num_codes, dtype=torch.float).uniform_(-1, 0)
+        self.scale = nn.Parameter(initial_scale, requires_grad=True)
+
+    @staticmethod
+    def _residuals(x, codewords):
+        """Difference between every feature and every codeword.
+
+        ``x`` is [batch_size, N, channels] and ``codewords`` is
+        [num_codes, channels]; the result is
+        [batch_size, N, num_codes, channels].
+        """
+        num_codes, channels = codewords.size()
+        tiled_x = x.unsqueeze(2).expand(
+            (x.size(0), x.size(1), num_codes, channels))
+        return tiled_x - codewords.view((1, 1, num_codes, channels))
+
+    @staticmethod
+    def scaled_l2(x, codewords, scale):
+        """Squared L2 distance to every codeword, times its scale.
+
+        Returns a tensor of shape [batch_size, N, num_codes].
+        """
+        num_codes = codewords.size(0)
+        squared_dist = Encoding._residuals(x, codewords).pow(2).sum(dim=3)
+        return scale.view((1, 1, num_codes)) * squared_dist
+
+    @staticmethod
+    def aggregate(assignment_weights, x, codewords):
+        """Sum the weighted residuals over the N features.
+
+        Returns a tensor of shape [batch_size, num_codes, channels].
+        """
+        weighted = assignment_weights.unsqueeze(3) * Encoding._residuals(
+            x, codewords)
+        return weighted.sum(dim=1)
+
+    def forward(self, x):
+        assert x.dim() == 4 and x.size(1) == self.channels
+        # x: [batch_size, channels, height, width]
+        num_samples = x.size(0)
+        # flattened: [batch_size, height x width, channels]
+        flattened = x.view(num_samples, self.channels,
+                           -1).transpose(1, 2).contiguous()
+        # assignment_weights: [batch_size, height x width, num_codes]
+        assignment_weights = F.softmax(
+            self.scaled_l2(flattened, self.codewords, self.scale), dim=2)
+        # encoded features: [batch_size, num_codes, channels]
+        return self.aggregate(assignment_weights, flattened, self.codewords)
+
+    def __repr__(self):
+        name = self.__class__.__name__
+        return (f'{name}(Nx{self.channels}xHxW =>Nx{self.num_codes}'
+                f'x{self.channels})')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/wrappers.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/ops/wrappers.py
@@ -0,0 +1,50 @@
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def resize(input, size=None, scale_factor=None, mode='nearest',
+           align_corners=None, warning=True):
+    """Resize `input` with `F.interpolate`, optionally warning on misalignment.
+
+    With `warning`, a given `size` and truthy `align_corners`, warn if the output
+    is taller or wider than the input and `out - 1` % `in - 1` != 0 for h and w.
+    """
+    if warning and size is not None and align_corners:
+        input_h, input_w = tuple(int(x) for x in input.shape[2:])
+        output_h, output_w = tuple(int(x) for x in size)
+        # Compare each output dimension with its own input dimension.
+        upsampling = output_h > input_h or output_w > input_w
+        if (upsampling and min(output_h, output_w, input_h, input_w) > 1
+                and (output_h - 1) % (input_h - 1)
+                and (output_w - 1) % (input_w - 1)):
+            warnings.warn(f'When align_corners={align_corners}, '
+                          'the output would more aligned if '
+                          f'input size {(input_h, input_w)} is `x+1` and '
+                          f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+    """Module wrapper around `resize` with a fixed size or a scale factor."""
+    def __init__(self, size=None, scale_factor=None, mode='nearest',
+                 align_corners=None):
+        super(Upsample, self).__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Resize `x` to `self.size`, or scale its last two dims if unset."""
+        if self.size:
+            return resize(x, self.size, None, self.mode, self.align_corners)
+        factors = self.scale_factor
+        if not isinstance(factors, tuple):  # one factor shared by both dims
+            factors = (factors, factors)
+        size = [int(t * f) for t, f in zip(x.shape[-2:], factors)]
+        return resize(x, size, None, self.mode, self.align_corners)
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/__init__.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/__init__.py
@@ -0,0 +1,4 @@
+from .collect_env import collect_env
+from .logger import get_root_logger
+
+__all__ = ['get_root_logger', 'collect_env']
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/collect_env.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..015d5a6b4f3ff31859cca36584879f646b3864d4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/collect_env.py
@@ -0,0 +1,17 @@
+from annotator.mmpkg.mmcv.utils import collect_env as collect_base_env
+from annotator.mmpkg.mmcv.utils import get_git_hash
+
+import annotator.mmpkg.mmseg as mmseg
+
+
+def collect_env():
+    """Return the base environment info extended with the mmseg version."""
+    info = collect_base_env()
+    version_tag = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+    info['MMSegmentation'] = version_tag
+    return info
+
+
+if __name__ == '__main__':
+    for key, value in collect_env().items():  # one "key: value" per line
+        print(f'{key}: {value}')
diff --git a/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/logger.py b/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c37733358e3e21479b41f54220bfe34b482009c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/mmpkg/mmseg/utils/logger.py
@@ -0,0 +1,27 @@
+import logging
+
+from annotator.mmpkg.mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    """Fetch the logger shared by the whole "mmseg" package.
+
+    Thin wrapper that delegates to ``get_logger`` with the fixed name
+    "mmseg"; initialization, handlers and per-rank levels are whatever that
+    helper provides (upstream notes: a StreamHandler is added by default, a
+    FileHandler when `log_file` is given, and non-zero ranks log at "Error"
+    level -- confirm against the mmcv implementation).
+
+    Args:
+        log_file (str | None): Log filename forwarded to ``get_logger``.
+            Defaults to None.
+        log_level (int): Logging level forwarded to ``get_logger``.
+            Defaults to ``logging.INFO``.
+
+    Returns:
+        logging.Logger: The logger returned by ``get_logger``.
+    """
+    package_name = 'mmseg'
+    root_logger = get_logger(
+        name=package_name, log_file=log_file, log_level=log_level)
+    return root_logger
diff --git a/sd-webui-controlnet/annotator/normalbae/LICENSE b/sd-webui-controlnet/annotator/normalbae/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/normalbae/__init__.py b/sd-webui-controlnet/annotator/normalbae/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74a94e1738c775b3754f7087b7ddbc6108c81a46
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/__init__.py
@@ -0,0 +1,81 @@
+import os
+import types
+import torch
+import numpy as np
+
+from einops import rearrange
+from .models.NNET import NNET
+from modules import devices
+from annotator.annotator_path import models_path
+import torchvision.transforms as transforms
+
+
+# load model
+def load_checkpoint(fpath, model):
+    """Load the 'model' state dict stored at `fpath` into `model`.
+
+    A leading 'module.' on a key (as left by a DataParallel-style wrapper,
+    presumably) is stripped; the rest of the key is kept untouched.
+    """
+    prefix = 'module.'
+    ckpt = torch.load(fpath, map_location='cpu')['model']
+    # Slice off only the prefix: str.replace would also drop inner matches.
+    load_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v
+                 for k, v in ckpt.items()}
+    model.load_state_dict(load_dict)
+    return model
+
+
+class NormalBaeDetector:
+    """Lazily loaded surface-normal annotator built on the NNET model."""
+
+    model_dir = os.path.join(models_path, "normal_bae")
+
+    def __init__(self):
+        # The network is created on first use (see load_model).
+        self.model = None
+        self.device = devices.get_device_for("controlnet")
+
+    def load_model(self):
+        """Download scannet.pt if it is missing, then build the eval-mode model."""
+        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/scannet.pt"
+        modelpath = os.path.join(self.model_dir, "scannet.pt")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=self.model_dir)
+        args = types.SimpleNamespace(
+            mode='client', architecture='BN', pretrained='scannet',
+            sampling_ratio=0.4, importance_ratio=0.7)
+        network = load_checkpoint(modelpath, NNET(args))
+        network.eval()
+        self.model = network.to(self.device)
+        self.norm = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    def unload_model(self):
+        """Move the model back to the CPU, if one has been loaded."""
+        if self.model is None:
+            return
+        self.model.cpu()
+
+    def __call__(self, input_image):
+        """Turn an `h w c` numpy image into a uint8 `h w c` normal image."""
+        if self.model is None:
+            self.load_model()
+        self.model.to(self.device)
+        assert input_image.ndim == 3
+
+        with torch.no_grad():
+            # Scale to [0, 1], add a batch axis and apply self.norm.
+            batch = torch.from_numpy(input_image).float().to(self.device)
+            batch = rearrange(batch / 255.0, 'h w c -> 1 c h w')
+            batch = self.norm(batch)
+
+            # Take the last entry of the model's first output and keep
+            # at most its first three channels.
+            normal = self.model(batch)[0][-1][:, :3]
+            # Map [-1, 1] to [0, 1] (no re-normalization), then to 8-bit.
+            normal = ((normal + 1) * 0.5).clip(0, 1)
+            normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
+            normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
+        return normal_image
diff --git a/sd-webui-controlnet/annotator/normalbae/models/NNET.py b/sd-webui-controlnet/annotator/normalbae/models/NNET.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ddbc50c3ac18aa4b7f16779fe3c0133981ecc7a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/NNET.py
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .submodules.encoder import Encoder
+from .submodules.decoder import Decoder
+
+
+class NNET(nn.Module):
+    """Encoder-decoder network: the decoder consumes the encoder's output."""
+    def __init__(self, args):
+        super().__init__()
+        self.encoder, self.decoder = Encoder(), Decoder(args)
+
+    def get_1x_lr_params(self):  # encoder group: lr/10 learning rate
+        return self.encoder.parameters()
+
+    def get_10x_lr_params(self):  # decoder group: lr learning rate
+        return self.decoder.parameters()
+
+    def forward(self, img, **kwargs):
+        return self.decoder(self.encoder(img), **kwargs)
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/normalbae/models/baseline.py b/sd-webui-controlnet/annotator/normalbae/models/baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..602d0fbdac1acc9ede9bc1f2e10a5df78831ce9d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/baseline.py
@@ -0,0 +1,85 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .submodules.submodules import UpSampleBN, norm_normalize
+
+
+# This is the baseline encoder-decoder we used in the ablation study
+class NNET(nn.Module):
+    """Baseline encoder-decoder whose output is resized to the input size."""
+    def __init__(self, args=None):
+        super().__init__()
+        self.encoder = Encoder()
+        self.decoder = Decoder(num_classes=4)
+
+    def forward(self, x, **kwargs):
+        """Predict at the input resolution; kwargs are passed to the decoder."""
+        out = self.decoder(self.encoder(x), **kwargs)
+        # Bilinearly upsample the output to match the input resolution
+        target_size = [x.size(2), x.size(3)]
+        up_out = F.interpolate(out, size=target_size, mode='bilinear',
+                               align_corners=False)
+        # Post-process with norm_normalize (upstream note: L2-normalizes the
+        # first three channels and keeps the concentration, kappa, positive).
+        return norm_normalize(up_out)
+
+    def get_1x_lr_params(self):  # lr/10 learning rate
+        return self.encoder.parameters()
+
+    def get_10x_lr_params(self):  # lr learning rate
+        yield from self.decoder.parameters()
+
+
+# Encoder
+class Encoder(nn.Module):
+    """Backbone loaded via torch.hub ('tf_efficientnet_b5_ap') as a feature tap."""
+    def __init__(self):
+        super().__init__()
+        basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch',
+                                   'tf_efficientnet_b5_ap', pretrained=True)
+        # Replace the pooling and classifier heads with identities.
+        basemodel.global_pool = nn.Identity()
+        basemodel.classifier = nn.Identity()
+        self.original_model = basemodel
+
+    def forward(self, x):
+        """Return [x, out_1, out_2, ...], one entry per applied submodule.
+
+        Children of 'blocks' are applied (and recorded) one at a time.
+        """
+        features = [x]
+        for name, child in self.original_model._modules.items():
+            stages = child._modules.values() if name == 'blocks' else [child]
+            for stage in stages:
+                features.append(stage(features[-1]))
+        return features
+
+
+# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
+class Decoder(nn.Module):
+    """Four skip-connected upsampling stages followed by a 3x3 output conv."""
+    def __init__(self, num_classes=4):
+        super().__init__()
+        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
+        self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
+        self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
+        self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
+        self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
+        self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1,
+                               padding=1)
+
+    def forward(self, features):
+        """Decode encoder features; entries 4, 5, 6, 8 and 11 are consumed."""
+        x = self.conv2(features[11])
+        stages = ((self.up1, 8), (self.up2, 6), (self.up3, 5), (self.up4, 4))
+        for upsample, skip_index in stages:
+            x = upsample(x, features[skip_index])
+        return self.conv3(x)
+
+
+if __name__ == '__main__':
+    # Smoke test. The baseline class in this file is NNET; no `Baseline`
+    # name is defined here, so the old `Baseline()` call raised NameError.
+    model = NNET()
+    print(model(torch.rand(2, 3, 480, 640)).shape)
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/decoder.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..993203d1792311f1c492091eaea3c1ac9088187f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/decoder.py
@@ -0,0 +1,202 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
+
+
+class Decoder(nn.Module):
+    def __init__(self, args):
+        super(Decoder, self).__init__()
+        # `args` must expose sampling_ratio, importance_ratio and architecture.
+        # hyper-parameters forwarded to sample_points when mode == 'train'
+        self.sampling_ratio = args.sampling_ratio
+        self.importance_ratio = args.importance_ratio
+
+        # feature-map decoder; `architecture` selects UpSampleBN or UpSampleGN
+        self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
+        if args.architecture == 'BN':
+            self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
+            self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
+            self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
+            self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
+
+        elif args.architecture == 'GN':
+            self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
+            self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
+            self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
+            self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
+
+        else:
+            raise Exception('invalid architecture')
+
+        # produces 1/8 res output (dense 3x3 conv on the 512-channel map)
+        self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+
+        # produces 1/4 res output (kernel-size-1 Conv1d stack on 512 features + 4 prediction channels)
+        self.out_conv_res4 = nn.Sequential(
+            nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 4, kernel_size=1),
+        )
+
+        # produces 1/2 res output (same stack on 256 features + 4 prediction channels)
+        self.out_conv_res2 = nn.Sequential(
+            nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 4, kernel_size=1),
+        )
+
+        # produces 1/1 res output (same stack on 128 features + 4 prediction channels)
+        self.out_conv_res1 = nn.Sequential(
+            nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
+            nn.Conv1d(128, 4, kernel_size=1),
+        )
+
+    def forward(self, features, gt_norm_mask=None, mode='test'):
+        x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
+        # In 'train' mode each refinement stage runs only on points chosen by
+        # sample_points; in any other mode it runs on every pixel.
+        # generate feature-map (shape comments below are example values)
+        x_d0 = self.conv2(x_block4)  # x_d0 : [2, 2048, 15, 20] 1/32 res
+        x_d1 = self.up1(x_d0, x_block3)  # x_d1 : [2, 1024, 30, 40] 1/16 res
+        x_d2 = self.up2(x_d1, x_block2)  # x_d2 : [2, 512, 60, 80] 1/8 res
+        x_d3 = self.up3(x_d2, x_block1)  # x_d3: [2, 256, 120, 160] 1/4 res
+        x_d4 = self.up4(x_d3, x_block0)  # x_d4: [2, 128, 240, 320] 1/2 res
+
+        # 1/8 res output
+        out_res8 = self.out_conv_res8(x_d2)  # out_res8: [2, 4, 60, 80] 1/8 res output
+        out_res8 = norm_normalize(out_res8)  # out_res8: [2, 4, 60, 80] 1/8 res output
+
+        # ---- out_res4 ------------------------------------------------------------------------------------------------
+        # Refine the 1/8 prediction to 1/4 res from the x_d2 features plus
+        # the previous prediction.
+
+        if mode == 'train':
+            # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
+            out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
+            B, _, H, W = out_res8_res4.shape
+
+            # samples: [B, 1, N, 2]
+            point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
+                                                                  sampling_ratio=self.sampling_ratio,
+                                                                  beta=self.importance_ratio)
+
+            # output (needed for evaluation / visualization); same tensor as out_res8_res4
+            out_res4 = out_res8_res4
+
+            # grid_sample feature-map
+            feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True)  # (B, 512, 1, N)
+            init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True)  # (B, 4, 1, N)
+            feat_res4 = torch.cat([feat_res4, init_pred], dim=1)  # (B, 512+4, 1, N)
+
+            # prediction (needed to compute loss)
+            samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :])  # (B, 4, N)
+            samples_pred_res4 = norm_normalize(samples_pred_res4)  # (B, 4, N) - normalized
+
+            for i in range(B):
+                out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
+
+        else:
+            # upsample feature-map and previous prediction (interpolation, no sampling)
+            feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
+            init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
+            feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 512+4, H, W)
+            B, _, H, W = feat_map.shape
+
+            # try all pixels
+            out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1))  # (B, 4, N)
+            out_res4 = norm_normalize(out_res4)  # (B, 4, N) - normalized
+            out_res4 = out_res4.view(B, 4, H, W)
+            samples_pred_res4 = point_coords_res4 = None
+
+        # ---- out_res2 ------------------------------------------------------------------------------------------------
+        # Refine the 1/4 prediction to 1/2 res from the x_d3 features plus
+        # the previous prediction.
+
+        if mode == 'train':
+
+            # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
+            out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
+            B, _, H, W = out_res4_res2.shape
+
+            # samples: [B, 1, N, 2]
+            point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
+                                                                  sampling_ratio=self.sampling_ratio,
+                                                                  beta=self.importance_ratio)
+
+            # output (needed for evaluation / visualization); same tensor as out_res4_res2
+            out_res2 = out_res4_res2
+
+            # grid_sample feature-map
+            feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True)  # (B, 256, 1, N)
+            init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True)  # (B, 4, 1, N)
+            feat_res2 = torch.cat([feat_res2, init_pred], dim=1)  # (B, 256+4, 1, N)
+
+            # prediction (needed to compute loss)
+            samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :])  # (B, 4, N)
+            samples_pred_res2 = norm_normalize(samples_pred_res2)  # (B, 4, N) - normalized
+
+            for i in range(B):
+                out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
+
+        else:
+            # upsample feature-map and previous prediction (interpolation, no sampling)
+            feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
+            init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
+            feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 256+4, H, W)
+            B, _, H, W = feat_map.shape
+
+            out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1))  # (B, 4, N)
+            out_res2 = norm_normalize(out_res2)  # (B, 4, N) - normalized
+            out_res2 = out_res2.view(B, 4, H, W)
+            samples_pred_res2 = point_coords_res2 = None
+
+        # ---- out_res1 ------------------------------------------------------------------------------------------------
+        # Refine the 1/2 prediction to 1/1 res from the x_d4 features plus
+        # the previous prediction.
+
+        if mode == 'train':
+            # upsampling ... out_res2: [2, 4, 240, 320] -> out_res2_res1: [2, 4, 480, 640]
+            out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
+            B, _, H, W = out_res2_res1.shape
+
+            # samples: [B, 1, N, 2]
+            point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
+                                                                  sampling_ratio=self.sampling_ratio,
+                                                                  beta=self.importance_ratio)
+
+            # output (needed for evaluation / visualization); same tensor as out_res2_res1
+            out_res1 = out_res2_res1
+
+            # grid_sample feature-map
+            feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True)  # (B, 128, 1, N)
+            init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True)  # (B, 4, 1, N)
+            feat_res1 = torch.cat([feat_res1, init_pred], dim=1)  # (B, 128+4, 1, N)
+
+            # prediction (needed to compute loss)
+            samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :])  # (B, 4, N)
+            samples_pred_res1 = norm_normalize(samples_pred_res1)  # (B, 4, N) - normalized
+
+            for i in range(B):
+                out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
+
+        else:
+            # upsample feature-map and previous prediction (interpolation, no sampling)
+            feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
+            init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
+            feat_map = torch.cat([feat_map, init_pred], dim=1)  # (B, 128+4, H, W)
+            B, _, H, W = feat_map.shape
+
+            out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1))  # (B, 4, N)
+            out_res1 = norm_normalize(out_res1)  # (B, 4, N) - normalized
+            out_res1 = out_res1.view(B, 4, H, W)
+            samples_pred_res1 = point_coords_res1 = None
+
+        # lists are ordered 1/8, 1/4, 1/2, 1/1 res; samples_pred_* and point_coords_* are None outside 'train' mode
+        return [out_res8, out_res4, out_res2, out_res1], \
+               [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
+               [None, point_coords_res4, point_coords_res2, point_coords_res1]
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/.gitignore b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f04e5fff91094d9b9c662bba977d762bf71516ac
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/.gitignore
@@ -0,0 +1,109 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# pytorch stuff
+*.pth
+*.onnx
+*.pb
+
+trained_models/
+.fuse_hidden*
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/BENCHMARK.md b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/BENCHMARK.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ead7171ce5a5bbd2702f6b5c825dc9808ba5658
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/BENCHMARK.md
@@ -0,0 +1,555 @@
+# Model Performance Benchmarks
+
+All benchmarks run as per:
+
+```
+python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx
+python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx
+python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3
+python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt
+python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb
+python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb
+```
+
+## EfficientNet-B0
+
+### Unoptimized
+```
+Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897
+Time per operator type:
+ 29.7378 ms. 60.5145%. Conv
+ 12.1785 ms. 24.7824%. Sigmoid
+ 3.62811 ms. 7.38297%. SpatialBN
+ 2.98444 ms. 6.07314%. Mul
+ 0.326902 ms. 0.665225%. AveragePool
+ 0.197317 ms. 0.401528%. FC
+ 0.0852877 ms. 0.173555%. Add
+ 0.0032607 ms. 0.00663532%. Squeeze
+ 49.1416 ms in Total
+FLOP per operator type:
+ 0.76907 GFLOP. 95.2696%. Conv
+ 0.0269508 GFLOP. 3.33857%. SpatialBN
+ 0.00846444 GFLOP. 1.04855%. Mul
+ 0.002561 GFLOP. 0.317248%. FC
+ 0.000210112 GFLOP. 0.0260279%. Add
+ 0.807256 GFLOP in Total
+Feature Memory Read per operator type:
+ 58.5253 MB. 43.0891%. Mul
+ 43.2015 MB. 31.807%. Conv
+ 27.2869 MB. 20.0899%. SpatialBN
+ 5.12912 MB. 3.77631%. FC
+ 1.6809 MB. 1.23756%. Add
+ 135.824 MB in Total
+Feature Memory Written per operator type:
+ 33.8578 MB. 38.1965%. Mul
+ 26.9881 MB. 30.4465%. Conv
+ 26.9508 MB. 30.4044%. SpatialBN
+ 0.840448 MB. 0.948147%. Add
+ 0.004 MB. 0.00451258%. FC
+ 88.6412 MB in Total
+Parameter Memory per operator type:
+ 15.8248 MB. 74.9391%. Conv
+ 5.124 MB. 24.265%. FC
+ 0.168064 MB. 0.795877%. SpatialBN
+ 0 MB. 0%. Add
+ 0 MB. 0%. Mul
+ 21.1168 MB in Total
+```
+### Optimized
+```
+Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996
+Time per operator type:
+ 29.776 ms. 65.002%. Conv
+ 12.2803 ms. 26.8084%. Sigmoid
+ 3.15073 ms. 6.87815%. Mul
+ 0.328651 ms. 0.717456%. AveragePool
+ 0.186237 ms. 0.406563%. FC
+ 0.0832429 ms. 0.181722%. Add
+ 0.0026184 ms. 0.00571606%. Squeeze
+ 45.8078 ms in Total
+FLOP per operator type:
+ 0.76907 GFLOP. 98.5601%. Conv
+ 0.00846444 GFLOP. 1.08476%. Mul
+ 0.002561 GFLOP. 0.328205%. FC
+ 0.000210112 GFLOP. 0.0269269%. Add
+ 0.780305 GFLOP in Total
+Feature Memory Read per operator type:
+ 58.5253 MB. 53.8803%. Mul
+ 43.2855 MB. 39.8501%. Conv
+ 5.12912 MB. 4.72204%. FC
+ 1.6809 MB. 1.54749%. Add
+ 108.621 MB in Total
+Feature Memory Written per operator type:
+ 33.8578 MB. 54.8834%. Mul
+ 26.9881 MB. 43.7477%. Conv
+ 0.840448 MB. 1.36237%. Add
+ 0.004 MB. 0.00648399%. FC
+ 61.6904 MB in Total
+Parameter Memory per operator type:
+ 15.8248 MB. 75.5403%. Conv
+ 5.124 MB. 24.4597%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Mul
+ 20.9488 MB in Total
+```
+
+## EfficientNet-B1
+### Optimized
+```
+Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256
+Time per operator type:
+ 45.7915 ms. 66.3206%. Conv
+ 17.8718 ms. 25.8841%. Sigmoid
+ 4.44132 ms. 6.43244%. Mul
+ 0.51001 ms. 0.738658%. AveragePool
+ 0.233283 ms. 0.337868%. Add
+ 0.194986 ms. 0.282402%. FC
+ 0.00268255 ms. 0.00388519%. Squeeze
+ 69.0456 ms in Total
+FLOP per operator type:
+ 1.37105 GFLOP. 98.7673%. Conv
+ 0.0138759 GFLOP. 0.99959%. Mul
+ 0.002561 GFLOP. 0.184489%. FC
+ 0.000674432 GFLOP. 0.0485847%. Add
+ 1.38816 GFLOP in Total
+Feature Memory Read per operator type:
+ 94.624 MB. 54.0789%. Mul
+ 69.8255 MB. 39.9062%. Conv
+ 5.39546 MB. 3.08357%. Add
+ 5.12912 MB. 2.93136%. FC
+ 174.974 MB in Total
+Feature Memory Written per operator type:
+ 55.5035 MB. 54.555%. Mul
+ 43.5333 MB. 42.7894%. Conv
+ 2.69773 MB. 2.65163%. Add
+ 0.004 MB. 0.00393165%. FC
+ 101.739 MB in Total
+Parameter Memory per operator type:
+ 25.7479 MB. 83.4024%. Conv
+ 5.124 MB. 16.5976%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Mul
+ 30.8719 MB in Total
+```
+
+## EfficientNet-B2
+### Optimized
+```
+Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366
+Time per operator type:
+ 61.4627 ms. 67.5845%. Conv
+ 22.7458 ms. 25.0113%. Sigmoid
+ 5.59931 ms. 6.15701%. Mul
+ 0.642567 ms. 0.706568%. AveragePool
+ 0.272795 ms. 0.299965%. Add
+ 0.216178 ms. 0.237709%. FC
+ 0.00268895 ms. 0.00295677%. Squeeze
+ 90.942 ms in Total
+FLOP per operator type:
+ 1.98431 GFLOP. 98.9343%. Conv
+ 0.0177039 GFLOP. 0.882686%. Mul
+ 0.002817 GFLOP. 0.140451%. FC
+ 0.000853984 GFLOP. 0.0425782%. Add
+ 2.00568 GFLOP in Total
+Feature Memory Read per operator type:
+ 120.609 MB. 54.9637%. Mul
+ 86.3512 MB. 39.3519%. Conv
+ 6.83187 MB. 3.11341%. Add
+ 5.64163 MB. 2.571%. FC
+ 219.433 MB in Total
+Feature Memory Written per operator type:
+ 70.8155 MB. 54.6573%. Mul
+ 55.3273 MB. 42.7031%. Conv
+ 3.41594 MB. 2.63651%. Add
+ 0.004 MB. 0.00308731%. FC
+ 129.563 MB in Total
+Parameter Memory per operator type:
+ 30.4721 MB. 84.3913%. Conv
+ 5.636 MB. 15.6087%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Mul
+ 36.1081 MB in Total
+```
+
+## MixNet-M
+### Optimized
+```
+Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448
+Time per operator type:
+ 48.1139 ms. 75.2052%. Conv
+ 7.1341 ms. 11.1511%. Sigmoid
+ 2.63706 ms. 4.12189%. SpatialBN
+ 1.73186 ms. 2.70701%. Mul
+ 1.38707 ms. 2.16809%. Split
+ 1.29322 ms. 2.02139%. Concat
+ 1.00093 ms. 1.56452%. Relu
+ 0.235309 ms. 0.367803%. Add
+ 0.221579 ms. 0.346343%. FC
+ 0.219315 ms. 0.342803%. AveragePool
+ 0.00250145 ms. 0.00390993%. Squeeze
+ 63.9768 ms in Total
+FLOP per operator type:
+ 0.675273 GFLOP. 95.5827%. Conv
+ 0.0221072 GFLOP. 3.12921%. SpatialBN
+ 0.00538445 GFLOP. 0.762152%. Mul
+ 0.003073 GFLOP. 0.434973%. FC
+ 0.000642488 GFLOP. 0.0909421%. Add
+ 0 GFLOP. 0%. Concat
+ 0 GFLOP. 0%. Relu
+ 0.70648 GFLOP in Total
+Feature Memory Read per operator type:
+ 46.8424 MB. 30.502%. Conv
+ 36.8626 MB. 24.0036%. Mul
+ 22.3152 MB. 14.5309%. SpatialBN
+ 22.1074 MB. 14.3955%. Concat
+ 14.1496 MB. 9.21372%. Relu
+ 6.15414 MB. 4.00735%. FC
+ 5.1399 MB. 3.34692%. Add
+ 153.571 MB in Total
+Feature Memory Written per operator type:
+ 32.7672 MB. 28.4331%. Conv
+ 22.1072 MB. 19.1831%. Concat
+ 22.1072 MB. 19.1831%. SpatialBN
+ 21.5378 MB. 18.689%. Mul
+ 14.1496 MB. 12.2781%. Relu
+ 2.56995 MB. 2.23003%. Add
+ 0.004 MB. 0.00347092%. FC
+ 115.243 MB in Total
+Parameter Memory per operator type:
+ 13.7059 MB. 68.674%. Conv
+ 6.148 MB. 30.8049%. FC
+ 0.104 MB. 0.521097%. SpatialBN
+ 0 MB. 0%. Add
+ 0 MB. 0%. Concat
+ 0 MB. 0%. Mul
+ 0 MB. 0%. Relu
+ 19.9579 MB in Total
+```
+
+## TF MobileNet-V3 Large 1.0
+
+### Optimized
+```
+Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525
+Time per operator type:
+ 17.437 ms. 80.0087%. Conv
+ 1.27662 ms. 5.8577%. Add
+ 1.12759 ms. 5.17387%. Div
+ 0.701155 ms. 3.21721%. Mul
+ 0.562654 ms. 2.58171%. Relu
+ 0.431144 ms. 1.97828%. Clip
+ 0.156902 ms. 0.719936%. FC
+ 0.0996858 ms. 0.457402%. AveragePool
+ 0.00112455 ms. 0.00515993%. Flatten
+ 21.7939 ms in Total
+FLOP per operator type:
+ 0.43062 GFLOP. 98.1484%. Conv
+ 0.002561 GFLOP. 0.583713%. FC
+ 0.00210867 GFLOP. 0.480616%. Mul
+ 0.00193868 GFLOP. 0.441871%. Add
+ 0.00151532 GFLOP. 0.345377%. Div
+ 0 GFLOP. 0%. Relu
+ 0.438743 GFLOP in Total
+Feature Memory Read per operator type:
+ 34.7967 MB. 43.9391%. Conv
+ 14.496 MB. 18.3046%. Mul
+ 9.44828 MB. 11.9307%. Add
+ 9.26157 MB. 11.6949%. Relu
+ 6.0614 MB. 7.65395%. Div
+ 5.12912 MB. 6.47673%. FC
+ 79.193 MB in Total
+Feature Memory Written per operator type:
+ 17.6247 MB. 35.8656%. Conv
+ 9.26157 MB. 18.847%. Relu
+ 8.43469 MB. 17.1643%. Mul
+ 7.75472 MB. 15.7806%. Add
+ 6.06128 MB. 12.3345%. Div
+ 0.004 MB. 0.00813985%. FC
+ 49.1409 MB in Total
+Parameter Memory per operator type:
+ 16.6851 MB. 76.5052%. Conv
+ 5.124 MB. 23.4948%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Div
+ 0 MB. 0%. Mul
+ 0 MB. 0%. Relu
+ 21.8091 MB in Total
+```
+
+## MobileNet-V3 (RW)
+
+### Unoptimized
+```
+Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712
+Time per operator type:
+ 15.9266 ms. 69.2624%. Conv
+ 2.36551 ms. 10.2873%. SpatialBN
+ 1.39102 ms. 6.04936%. Add
+ 1.30327 ms. 5.66773%. Div
+ 0.737014 ms. 3.20517%. Mul
+ 0.639697 ms. 2.78195%. Relu
+ 0.375681 ms. 1.63378%. Clip
+ 0.153126 ms. 0.665921%. FC
+ 0.0993787 ms. 0.432184%. AveragePool
+ 0.0032632 ms. 0.0141912%. Squeeze
+ 22.9946 ms in Total
+FLOP per operator type:
+ 0.430616 GFLOP. 94.4041%. Conv
+ 0.0175992 GFLOP. 3.85829%. SpatialBN
+ 0.002561 GFLOP. 0.561449%. FC
+ 0.00210961 GFLOP. 0.46249%. Mul
+ 0.00173891 GFLOP. 0.381223%. Add
+ 0.00151626 GFLOP. 0.33241%. Div
+ 0 GFLOP. 0%. Relu
+ 0.456141 GFLOP in Total
+Feature Memory Read per operator type:
+ 34.7354 MB. 36.4363%. Conv
+ 17.7944 MB. 18.6658%. SpatialBN
+ 14.5035 MB. 15.2137%. Mul
+ 9.25778 MB. 9.71113%. Relu
+ 7.84641 MB. 8.23064%. Add
+ 6.06516 MB. 6.36216%. Div
+ 5.12912 MB. 5.38029%. FC
+ 95.3317 MB in Total
+Feature Memory Written per operator type:
+ 17.6246 MB. 26.7264%. Conv
+ 17.5992 MB. 26.6878%. SpatialBN
+ 9.25778 MB. 14.0387%. Relu
+ 8.43843 MB. 12.7962%. Mul
+ 6.95565 MB. 10.5477%. Add
+ 6.06502 MB. 9.19713%. Div
+ 0.004 MB. 0.00606568%. FC
+ 65.9447 MB in Total
+Parameter Memory per operator type:
+ 16.6778 MB. 76.1564%. Conv
+ 5.124 MB. 23.3979%. FC
+ 0.0976 MB. 0.445674%. SpatialBN
+ 0 MB. 0%. Add
+ 0 MB. 0%. Div
+ 0 MB. 0%. Mul
+ 0 MB. 0%. Relu
+ 21.8994 MB in Total
+
+```
+### Optimized
+
+```
+Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527
+Time per operator type:
+ 17.146 ms. 78.8965%. Conv
+ 1.38453 ms. 6.37084%. Add
+ 1.30991 ms. 6.02749%. Div
+ 0.685417 ms. 3.15391%. Mul
+ 0.532589 ms. 2.45068%. Relu
+ 0.418263 ms. 1.92461%. Clip
+ 0.15128 ms. 0.696106%. FC
+ 0.102065 ms. 0.469648%. AveragePool
+ 0.0022143 ms. 0.010189%. Squeeze
+ 21.7323 ms in Total
+FLOP per operator type:
+ 0.430616 GFLOP. 98.1927%. Conv
+ 0.002561 GFLOP. 0.583981%. FC
+ 0.00210961 GFLOP. 0.481051%. Mul
+ 0.00173891 GFLOP. 0.396522%. Add
+ 0.00151626 GFLOP. 0.34575%. Div
+ 0 GFLOP. 0%. Relu
+ 0.438542 GFLOP in Total
+Feature Memory Read per operator type:
+ 34.7842 MB. 44.833%. Conv
+ 14.5035 MB. 18.6934%. Mul
+ 9.25778 MB. 11.9323%. Relu
+ 7.84641 MB. 10.1132%. Add
+ 6.06516 MB. 7.81733%. Div
+ 5.12912 MB. 6.61087%. FC
+ 77.5861 MB in Total
+Feature Memory Written per operator type:
+ 17.6246 MB. 36.4556%. Conv
+ 9.25778 MB. 19.1492%. Relu
+ 8.43843 MB. 17.4544%. Mul
+ 6.95565 MB. 14.3874%. Add
+ 6.06502 MB. 12.5452%. Div
+ 0.004 MB. 0.00827378%. FC
+ 48.3455 MB in Total
+Parameter Memory per operator type:
+ 16.6778 MB. 76.4973%. Conv
+ 5.124 MB. 23.5027%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Div
+ 0 MB. 0%. Mul
+ 0 MB. 0%. Relu
+ 21.8018 MB in Total
+
+```
+
+## MnasNet-A1
+
+### Unoptimized
+```
+Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345
+Time per operator type:
+ 24.4656 ms. 79.0905%. Conv
+ 4.14958 ms. 13.4144%. SpatialBN
+ 1.60598 ms. 5.19169%. Relu
+ 0.295219 ms. 0.95436%. Mul
+ 0.187609 ms. 0.606486%. FC
+ 0.120556 ms. 0.389724%. AveragePool
+ 0.09036 ms. 0.292109%. Add
+ 0.015727 ms. 0.050841%. Sigmoid
+ 0.00306205 ms. 0.00989875%. Squeeze
+ 30.9337 ms in Total
+FLOP per operator type:
+ 0.620598 GFLOP. 95.6434%. Conv
+ 0.0248873 GFLOP. 3.8355%. SpatialBN
+ 0.002561 GFLOP. 0.394688%. FC
+ 0.000597408 GFLOP. 0.0920695%. Mul
+ 0.000222656 GFLOP. 0.0343146%. Add
+ 0 GFLOP. 0%. Relu
+ 0.648867 GFLOP in Total
+Feature Memory Read per operator type:
+ 35.5457 MB. 38.4109%. Conv
+ 25.1552 MB. 27.1829%. SpatialBN
+ 22.5235 MB. 24.339%. Relu
+ 5.12912 MB. 5.54256%. FC
+ 2.40586 MB. 2.59978%. Mul
+ 1.78125 MB. 1.92483%. Add
+ 92.5406 MB in Total
+Feature Memory Written per operator type:
+ 24.9042 MB. 32.9424%. Conv
+ 24.8873 MB. 32.92%. SpatialBN
+ 22.5235 MB. 29.7932%. Relu
+ 2.38963 MB. 3.16092%. Mul
+ 0.890624 MB. 1.17809%. Add
+ 0.004 MB. 0.00529106%. FC
+ 75.5993 MB in Total
+Parameter Memory per operator type:
+ 10.2732 MB. 66.1459%. Conv
+ 5.124 MB. 32.9917%. FC
+ 0.133952 MB. 0.86247%. SpatialBN
+ 0 MB. 0%. Add
+ 0 MB. 0%. Mul
+ 0 MB. 0%. Relu
+ 15.5312 MB in Total
+```
+
+### Optimized
+```
+Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597
+Time per operator type:
+ 22.0547 ms. 91.1375%. Conv
+ 1.49096 ms. 6.16116%. Relu
+ 0.253417 ms. 1.0472%. Mul
+ 0.18506 ms. 0.76473%. FC
+ 0.112942 ms. 0.466717%. AveragePool
+ 0.086769 ms. 0.358559%. Add
+ 0.0127889 ms. 0.0528479%. Sigmoid
+ 0.0027346 ms. 0.0113003%. Squeeze
+ 24.1994 ms in Total
+FLOP per operator type:
+ 0.620598 GFLOP. 99.4581%. Conv
+ 0.002561 GFLOP. 0.41043%. FC
+ 0.000597408 GFLOP. 0.0957417%. Mul
+ 0.000222656 GFLOP. 0.0356832%. Add
+ 0 GFLOP. 0%. Relu
+ 0.623979 GFLOP in Total
+Feature Memory Read per operator type:
+ 35.6127 MB. 52.7968%. Conv
+ 22.5235 MB. 33.3917%. Relu
+ 5.12912 MB. 7.60406%. FC
+ 2.40586 MB. 3.56675%. Mul
+ 1.78125 MB. 2.64075%. Add
+ 67.4524 MB in Total
+Feature Memory Written per operator type:
+ 24.9042 MB. 49.1092%. Conv
+ 22.5235 MB. 44.4145%. Relu
+ 2.38963 MB. 4.71216%. Mul
+ 0.890624 MB. 1.75624%. Add
+ 0.004 MB. 0.00788768%. FC
+ 50.712 MB in Total
+Parameter Memory per operator type:
+ 10.2732 MB. 66.7213%. Conv
+ 5.124 MB. 33.2787%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Mul
+ 0 MB. 0%. Relu
+ 15.3972 MB in Total
+```
+## MnasNet-B1
+
+### Unoptimized
+```
+Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322
+Time per operator type:
+ 29.1121 ms. 83.3081%. Conv
+ 4.14959 ms. 11.8746%. SpatialBN
+ 1.35823 ms. 3.88675%. Relu
+ 0.186188 ms. 0.532802%. FC
+ 0.116244 ms. 0.332647%. Add
+ 0.018641 ms. 0.0533437%. AveragePool
+ 0.0040904 ms. 0.0117052%. Squeeze
+ 34.9451 ms in Total
+FLOP per operator type:
+ 0.626272 GFLOP. 96.2088%. Conv
+ 0.0218266 GFLOP. 3.35303%. SpatialBN
+ 0.002561 GFLOP. 0.393424%. FC
+ 0.000291648 GFLOP. 0.0448034%. Add
+ 0 GFLOP. 0%. Relu
+ 0.650951 GFLOP in Total
+Feature Memory Read per operator type:
+ 34.4354 MB. 41.3788%. Conv
+ 22.1299 MB. 26.5921%. SpatialBN
+ 19.1923 MB. 23.0622%. Relu
+ 5.12912 MB. 6.16333%. FC
+ 2.33318 MB. 2.80364%. Add
+ 83.2199 MB in Total
+Feature Memory Written per operator type:
+ 21.8266 MB. 34.0955%. Conv
+ 21.8266 MB. 34.0955%. SpatialBN
+ 19.1923 MB. 29.9805%. Relu
+ 1.16659 MB. 1.82234%. Add
+ 0.004 MB. 0.00624844%. FC
+ 64.016 MB in Total
+Parameter Memory per operator type:
+ 12.2576 MB. 69.9104%. Conv
+ 5.124 MB. 29.2245%. FC
+ 0.15168 MB. 0.865099%. SpatialBN
+ 0 MB. 0%. Add
+ 0 MB. 0%. Relu
+ 17.5332 MB in Total
+```
+
+### Optimized
+```
+Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426
+Time per operator type:
+ 24.9888 ms. 94.0962%. Conv
+ 1.26147 ms. 4.75011%. Relu
+ 0.176234 ms. 0.663619%. FC
+ 0.113309 ms. 0.426672%. Add
+ 0.0138708 ms. 0.0522311%. AveragePool
+ 0.00295685 ms. 0.0111341%. Squeeze
+ 26.5566 ms in Total
+FLOP per operator type:
+ 0.626272 GFLOP. 99.5466%. Conv
+ 0.002561 GFLOP. 0.407074%. FC
+ 0.000291648 GFLOP. 0.0463578%. Add
+ 0 GFLOP. 0%. Relu
+ 0.629124 GFLOP in Total
+Feature Memory Read per operator type:
+ 34.5112 MB. 56.4224%. Conv
+ 19.1923 MB. 31.3775%. Relu
+ 5.12912 MB. 8.3856%. FC
+ 2.33318 MB. 3.81452%. Add
+ 61.1658 MB in Total
+Feature Memory Written per operator type:
+ 21.8266 MB. 51.7346%. Conv
+ 19.1923 MB. 45.4908%. Relu
+ 1.16659 MB. 2.76513%. Add
+ 0.004 MB. 0.00948104%. FC
+ 42.1895 MB in Total
+Parameter Memory per operator type:
+ 12.2576 MB. 70.5205%. Conv
+ 5.124 MB. 29.4795%. FC
+ 0 MB. 0%. Add
+ 0 MB. 0%. Relu
+ 17.3816 MB in Total
+```
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/LICENSE b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..80e7d15508202f3262a50db27f5198460d7f509f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2020 Ross Wightman
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/README.md b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..463368280d6a5015060eb73d20fe6512f8e04c50
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/README.md
@@ -0,0 +1,323 @@
+# (Generic) EfficientNets for PyTorch
+
+A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.
+
+All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py))
+
+## What's New
+
+### Aug 19, 2020
+* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1)
+* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1)
+* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX
+* ONNX runtime based validation script added
+* activations (mostly) brought in sync with `timm` equivalents
+
+
+### April 5, 2020
+* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
+ * 3.5M param MobileNet-V2 100 @ 73%
+ * 4.5M param MobileNet-V2 110d @ 75%
+ * 6.1M param MobileNet-V2 140 @ 76.5%
+ * 5.8M param MobileNet-V2 120d @ 77.3%
+
+### March 23, 2020
+ * Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
+ * Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1
+ * IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior
+
+### Feb 12, 2020
+ * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
+ * Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization.
+ * Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin)
+
+### Jan 22, 2020
+ * Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models)
+ * Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict
+ * Test models, torchscript, onnx export with PyTorch 1.4 -- no issues
+
+### Nov 22, 2019
+ * New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different
+ preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights.
+
+### Nov 15, 2019
+ * Ported official TF MobileNet-V3 float32 large/small/minimalistic weights
+ * Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine
+
+### Oct 30, 2019
+ * Many of the models will now work with torch.jit.script, MixNet being the biggest exception
+ * Improved interface for enabling torchscript or ONNX export compatible modes (via config)
+ * Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autograd.fn
+ * Activation factory to select best version of activation by name or override one globally
+ * Add pretrained checkpoint load helper that handles input conv and classifier changes
+
+### Oct 27, 2019
+ * Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
+ * Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
+ * Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base
+ * Switch activations and global pooling to modules
+ * Add memory-efficient Swish/Mish impl
+ * Add as_sequential() method to all models and allow as an argument in entrypoint fns
+ * Move MobileNetV3 into own file since it has a different head
+ * Remove ChamNet, MobileNet V2/V1 since they will likely never be used here
+
+## Models
+
+Implemented models include:
+ * EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
+ * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
+ * EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
+ * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
+ * EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
+ * EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
+ * MixNet (https://arxiv.org/abs/1907.09595)
+ * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
+ * MobileNet-V3 (https://arxiv.org/abs/1905.02244)
+ * FBNet-C (https://arxiv.org/abs/1812.03443)
+ * Single-Path NAS (https://arxiv.org/abs/1904.02877)
+
+I originally implemented and trained some of these models with code [here](https://github.com/rwightman/pytorch-image-models); this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.
+
+## Pretrained
+
+I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models
+
+
+|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop |
+|---|---|---|---|---|---|---|---|
+| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 |
+| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 |
+| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 |
+| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 |
+| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 |
+| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 |
+| mixnet_l | 78.976 (21.024) | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 |
+| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 |
+| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 |
+| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 |
+| mobilenetv2_120d | 77.294 (22.706) | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 |
+| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 |
+| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 |
+| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 |
+| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 |
+| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 |
+| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 |
+| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 |
+| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 |
+| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 |
+| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 |
+| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 |
+| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 |
+
+
+More pretrained models to come...
+
+
+## Ported Weights
+
+The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args.
+
+**IMPORTANT:**
+* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
+* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl.
+
+To run validation for tf_efficientnet_b5:
+`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic`
+
+To run validation w/ TF preprocessing for tf_efficientnet_b5:
+`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing`
+
+To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp:
+`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5`
+
+|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop |
+|---|---|---|---|---|---|---|
+| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A |
+| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 |
+| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 |
+| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A |
+| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A |
+| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A |
+| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A |
+| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A |
+| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A |
+| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A |
+| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A |
+| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A |
+| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 |
+| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 |
+| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A |
+| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 |
+| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A |
+| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 |
+| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A |
+| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 |
+| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 |
+| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A |
+| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A |
+| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 |
+| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A |
+| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 |
+| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A |
+| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | 0.904 |
+| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A |
+| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 |
+| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A |
+| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 |
+| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 |
+| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A |
+| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A |
+| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 |
+| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
+| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 |
+| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 |
+| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
+| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 |
+| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A |
+| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A |
+| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 |
+| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 |
+| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A |
+| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A |
+| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 |
+| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A |
+| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 |
+| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 |
+| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A |
+| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A |
+| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536) | 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 |
+| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 |
+| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 |
+| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A |
+| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A |
+| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A |
+| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A |
+| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 |
+| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 |
+| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 |
+| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 |
+| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 |
+| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 |
+| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 |
+| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 |
+| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A |
+| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A |
+| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 |
+| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A |
+| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A |
+| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A |
+| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 |
+| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A |
+| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 |
+| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 |
+| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A |
+| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 |
+| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A |
+| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A |
+| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 |
+| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 |
+| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A |
+| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 |
+| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 | N/A |
+| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 |
+| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A |
+| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 |
+| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042) | 2.54 | bilinear | 224 | N/A |
+| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 |
+| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A |
+| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 |
+| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A |
+| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 |
+
+
+*tfp models validated with `tf-preprocessing` pipeline
+
+Google tf and tflite weights ported from official Tensorflow repositories
+* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
+* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet
+
+## Usage
+
+### Environment
+
+All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x.
+
+Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.
+
+PyTorch versions 1.4, 1.5, 1.6 have been tested with this code.
+
+I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
+```
+conda create -n torch-env
+conda activate torch-env
+conda install -c pytorch pytorch torchvision cudatoolkit=10.2
+```
+
+### PyTorch Hub
+
+Models can be accessed via the PyTorch Hub API
+
+```
+>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
+['efficientnet_b0', ...]
+>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
+>>> model.eval()
+>>> output = model(torch.randn(1,3,224,224))
+```
+
+### Pip
+This package can be installed via pip.
+
+Install (after conda env/install):
+```
+pip install geffnet
+```
+
+Eval use:
+```
+>>> import geffnet
+>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
+>>> m.eval()
+```
+
+Train use:
+```
+>>> import geffnet
+>>> # models can also be created by using the entrypoint directly
+>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
+>>> m.train()
+```
+
+Create in a nn.Sequential container, for fast.ai, etc:
+```
+>>> import geffnet
+>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)
+```
+
+### Exporting
+
+Scripts are included to
+* export models to ONNX (`onnx_export.py`)
+* optimize the ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg)
+* validate with ONNX runtime (`onnx_validate.py`)
+* convert ONNX model to Caffe2 (`onnx_to_caffe.py`)
+* validate in Caffe2 (`caffe2_validate.py`)
+* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`)
+
+As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:
+```
+python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
+python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx
+```
+
+These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible
+export now requires additional args mentioned in the export script (not needed in earlier versions).
+
+#### Export Notes
+1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script.
+2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
+3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization.
+4. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.
+
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/caffe2_benchmark.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/caffe2_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..93f28a1e63d9f7287ca02997c7991fe66dd0aeb9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/caffe2_benchmark.py
@@ -0,0 +1,65 @@
+""" Caffe2 benchmark script
+
+This script runs Caffe2 benchmark on exported ONNX model.
+It is a useful tool for reporting model FLOPS.
+
+Copyright 2020 Ross Wightman
+"""
+import argparse
+from caffe2.python import core, workspace, model_helper
+from caffe2.proto import caffe2_pb2
+
+
+parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
+parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
+ help='caffe2 model pb name prefix')
+parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
+ help='caffe2 model init .pb')
+parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
+ help='caffe2 model predict .pb')
+parser.add_argument('-b', '--batch-size', default=1, type=int,
+ metavar='N', help='mini-batch size (default: 1)')
+parser.add_argument('--img-size', default=224, type=int,
+ metavar='N', help='Input image dimension, uses model default if empty')
+
+
+def main():
+    args = parser.parse_args()
+    args.gpu_id = 0  # only read by the commented-out CUDA lines further down
+    if args.c2_prefix:  # a prefix replaces whatever was passed via --c2-init / --c2-predict
+        args.c2_init = args.c2_prefix + '.init.pb'
+        args.c2_predict = args.c2_prefix + '.predict.pb'
+
+    model = model_helper.ModelHelper(name="le_net", init_params=False)  # both of its nets are replaced below
+
+    # Load the serialized init NetDef from args.c2_init
+    init_net_proto = caffe2_pb2.NetDef()
+    with open(args.c2_init, "rb") as f:
+        init_net_proto.ParseFromString(f.read())
+    model.param_init_net = core.Net(init_net_proto)
+
+    # Load the serialized predict NetDef from args.c2_predict
+    predict_net_proto = caffe2_pb2.NetDef()
+    with open(args.c2_predict, "rb") as f:
+        predict_net_proto.ParseFromString(f.read())
+    model.net = core.Net(predict_net_proto)
+
+    # CUDA performance not impressive
+    #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
+    #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
+    #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
+
+    input_blob = model.net.external_inputs[0]  # assumes the first external input is the image blob - TODO confirm
+    model.param_init_net.GaussianFill(  # appended to the init net: fills the input blob with N(0, 1) noise
+        [],
+        input_blob.GetUnscopedName(),
+        shape=(args.batch_size, 3, args.img_size, args.img_size),
+        mean=0.0,
+        std=1.0)
+    workspace.RunNetOnce(model.param_init_net)
+    workspace.CreateNet(model.net, overwrite=True)
+    workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)  # NOTE(review): presumably warmup=5, runs=20, per-op timing - confirm
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/caffe2_validate.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/caffe2_validate.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cfaab38c095663fe32e4addbdf06b57bcb53614
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/caffe2_validate.py
@@ -0,0 +1,138 @@
+""" Caffe2 validation script
+
+This script is created to verify exported ONNX models running in Caffe2
+It utilizes the same PyTorch dataloader/processing pipeline for a
+fair comparison against the originals.
+
+Copyright 2020 Ross Wightman
+"""
+import argparse
+import numpy as np
+from caffe2.python import core, workspace, model_helper
+from caffe2.proto import caffe2_pb2
+from data import create_loader, resolve_data_config, Dataset
+from utils import AverageMeter
+import time
+
+parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
+parser.add_argument('data', metavar='DIR',
+ help='path to dataset')
+parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
+ help='caffe2 model pb name prefix')
+parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
+ help='caffe2 model init .pb')
+parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
+ help='caffe2 model predict .pb')
+parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
+ help='number of data loading workers (default: 2)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+ metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--img-size', default=None, type=int,
+ metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+ help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+ help='Override std deviation of of dataset')
+parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
+ help='Override default crop pct of 0.875')
+parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
+ help='Image resize interpolation type (overrides model)')
+parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
+ help='use tensorflow mnasnet preporcessing')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+ metavar='N', help='print frequency (default: 10)')
+
+
+def main():
+ args = parser.parse_args()
+ args.gpu_id = 0
+ if args.c2_prefix:
+ args.c2_init = args.c2_prefix + '.init.pb'
+ args.c2_predict = args.c2_prefix + '.predict.pb'
+
+ model = model_helper.ModelHelper(name="validation_net", init_params=False)
+
+ # Bring in the init net from init_net.pb
+ init_net_proto = caffe2_pb2.NetDef()
+ with open(args.c2_init, "rb") as f:
+ init_net_proto.ParseFromString(f.read())
+ model.param_init_net = core.Net(init_net_proto)
+
+ # bring in the predict net from predict_net.pb
+ predict_net_proto = caffe2_pb2.NetDef()
+ with open(args.c2_predict, "rb") as f:
+ predict_net_proto.ParseFromString(f.read())
+ model.net = core.Net(predict_net_proto)
+
+ data_config = resolve_data_config(None, args)
+ loader = create_loader(
+ Dataset(args.data, load_bytes=args.tf_preprocessing),
+ input_size=data_config['input_size'],
+ batch_size=args.batch_size,
+ use_prefetcher=False,
+ interpolation=data_config['interpolation'],
+ mean=data_config['mean'],
+ std=data_config['std'],
+ num_workers=args.workers,
+ crop_pct=data_config['crop_pct'],
+ tensorflow_preprocessing=args.tf_preprocessing)
+
+ # this is so obvious, wonderful interface
+ input_blob = model.net.external_inputs[0]
+ output_blob = model.net.external_outputs[0]
+
+ if True:
+ device_opts = None
+ else:
+ # CUDA is crashing, no idea why, awesome error message, give it a try for kicks
+ device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
+ model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
+ model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
+
+ model.param_init_net.GaussianFill(
+ [], input_blob.GetUnscopedName(),
+ shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
+ workspace.RunNetOnce(model.param_init_net)
+ workspace.CreateNet(model.net, overwrite=True)
+
+ batch_time = AverageMeter()
+ top1 = AverageMeter()
+ top5 = AverageMeter()
+ end = time.time()
+ for i, (input, target) in enumerate(loader):
+ # run the net and return prediction
+ caffe2_in = input.data.numpy()
+ workspace.FeedBlob(input_blob, caffe2_in, device_opts)
+ workspace.RunNet(model.net, num_iter=1)
+ output = workspace.FetchBlob(output_blob)
+
+ # measure accuracy and record loss
+ prec1, prec5 = accuracy_np(output.data, target.numpy())
+ top1.update(prec1.item(), input.size(0))
+ top5.update(prec5.item(), input.size(0))
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args.print_freq == 0:
+ print('Test: [{0}/{1}]\t'
+ 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
+ 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+ 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+ i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
+ ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
+
+ print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
+ top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
+
+
+def accuracy_np(output, target):
+ max_indices = np.argsort(output, axis=1)[:, ::-1]
+ top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
+ top1 = 100 * np.equal(max_indices[:, 0], target).mean()
+ return top1, top5
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/__init__.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e441a5838d1e972823b9668ac8d459445f6f6ce
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/__init__.py
@@ -0,0 +1,5 @@
+from .gen_efficientnet import *
+from .mobilenetv3 import *
+from .model_factory import create_model
+from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
+from .activations import *
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/__init__.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..813421a743ffc33b8eb53ebf62dd4a03d831b654
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/__init__.py
@@ -0,0 +1,137 @@
+from geffnet import config
+from geffnet.activations.activations_me import *
+from geffnet.activations.activations_jit import *
+from geffnet.activations.activations import *
+import torch
+
+_has_silu = 'silu' in dir(torch.nn.functional)
+
+_ACT_FN_DEFAULT = dict(
+ silu=F.silu if _has_silu else swish,
+ swish=F.silu if _has_silu else swish,
+ mish=mish,
+ relu=F.relu,
+ relu6=F.relu6,
+ sigmoid=sigmoid,
+ tanh=tanh,
+ hard_sigmoid=hard_sigmoid,
+ hard_swish=hard_swish,
+)
+
+_ACT_FN_JIT = dict(
+ silu=F.silu if _has_silu else swish_jit,
+ swish=F.silu if _has_silu else swish_jit,
+ mish=mish_jit,
+)
+
+_ACT_FN_ME = dict(
+ silu=F.silu if _has_silu else swish_me,
+ swish=F.silu if _has_silu else swish_me,
+ mish=mish_me,
+ hard_swish=hard_swish_me,
+ hard_sigmoid_jit=hard_sigmoid_me,
+)
+
+_ACT_LAYER_DEFAULT = dict(
+ silu=nn.SiLU if _has_silu else Swish,
+ swish=nn.SiLU if _has_silu else Swish,
+ mish=Mish,
+ relu=nn.ReLU,
+ relu6=nn.ReLU6,
+ sigmoid=Sigmoid,
+ tanh=Tanh,
+ hard_sigmoid=HardSigmoid,
+ hard_swish=HardSwish,
+)
+
+_ACT_LAYER_JIT = dict(
+ silu=nn.SiLU if _has_silu else SwishJit,
+ swish=nn.SiLU if _has_silu else SwishJit,
+ mish=MishJit,
+)
+
+_ACT_LAYER_ME = dict(
+ silu=nn.SiLU if _has_silu else SwishMe,
+ swish=nn.SiLU if _has_silu else SwishMe,
+ mish=MishMe,
+ hard_swish=HardSwishMe,
+ hard_sigmoid=HardSigmoidMe
+)
+
+_OVERRIDE_FN = dict()
+_OVERRIDE_LAYER = dict()
+
+
+def add_override_act_fn(name, fn):
+ global _OVERRIDE_FN
+ _OVERRIDE_FN[name] = fn
+
+
+def update_override_act_fn(overrides):
+ assert isinstance(overrides, dict)
+ global _OVERRIDE_FN
+ _OVERRIDE_FN.update(overrides)
+
+
+def clear_override_act_fn():
+ global _OVERRIDE_FN
+ _OVERRIDE_FN = dict()
+
+
+def add_override_act_layer(name, fn):
+ _OVERRIDE_LAYER[name] = fn
+
+
+def update_override_act_layer(overrides):
+ assert isinstance(overrides, dict)
+ global _OVERRIDE_LAYER
+ _OVERRIDE_LAYER.update(overrides)
+
+
+def clear_override_act_layer():
+ global _OVERRIDE_LAYER
+ _OVERRIDE_LAYER = dict()
+
+
+def get_act_fn(name='relu'):
+ """ Activation Function Factory
+ Fetching activation fns by name with this function allows export or torch script friendly
+ functions to be returned dynamically based on current config.
+ """
+ if name in _OVERRIDE_FN:
+ return _OVERRIDE_FN[name]
+ use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
+ if use_me and name in _ACT_FN_ME:
+ # If not exporting or scripting the model, first look for a memory optimized version
+ # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
+ return _ACT_FN_ME[name]
+ if config.is_exportable() and name in ('silu', 'swish'):
+ # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+ return swish
+ use_jit = not (config.is_exportable() or config.is_no_jit())
+ # NOTE: export tracing should work with jit scripted components, but I keep running into issues
+ if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
+ return _ACT_FN_JIT[name]
+ return _ACT_FN_DEFAULT[name]
+
+
+def get_act_layer(name='relu'):
+ """ Activation Layer Factory
+ Fetching activation layers by name with this function allows export or torch script friendly
+ functions to be returned dynamically based on current config.
+ """
+ if name in _OVERRIDE_LAYER:
+ return _OVERRIDE_LAYER[name]
+ use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
+ if use_me and name in _ACT_LAYER_ME:
+ return _ACT_LAYER_ME[name]
+ if config.is_exportable() and name in ('silu', 'swish'):
+ # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+ return Swish
+ use_jit = not (config.is_exportable() or config.is_no_jit())
+ # NOTE: export tracing should work with jit scripted components, but I keep running into issues
+ if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
+ return _ACT_LAYER_JIT[name]
+ return _ACT_LAYER_DEFAULT[name]
+
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdea692d1397673b2513d898c33edbcb37d94240
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations.py
@@ -0,0 +1,102 @@
+""" Activations
+
+A collection of activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+def swish(x, inplace: bool = False):
+ """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
+ and also as Swish (https://arxiv.org/abs/1710.05941).
+
+ TODO Rename to SiLU with addition to PyTorch
+ """
+ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
+
+
+class Swish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Swish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return swish(x, self.inplace)
+
+
+def mish(x, inplace: bool = False):
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """
+    return x.mul(F.softplus(x).tanh())  # `inplace` is accepted for interface parity but never applied here
+
+
+class Mish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Mish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return mish(x, self.inplace)
+
+
+def sigmoid(x, inplace: bool = False):
+ return x.sigmoid_() if inplace else x.sigmoid()
+
+
+# PyTorch has this, but not with a consistent inplace argument interface
+class Sigmoid(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Sigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x.sigmoid_() if self.inplace else x.sigmoid()
+
+
+def tanh(x, inplace: bool = False):
+ return x.tanh_() if inplace else x.tanh()
+
+
+# PyTorch has this, but not with a consistent inplace argument interface
+class Tanh(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Tanh, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x.tanh_() if self.inplace else x.tanh()
+
+
+def hard_swish(x, inplace: bool = False):
+    inner = F.relu6(x + 3.).div_(6.)  # relu6(x + 3) / 6; div_ acts on the fresh relu6 result, not on x
+    return x.mul_(inner) if inplace else x.mul(inner)  # only this final multiply honours `inplace`
+
+
+class HardSwish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_swish(x, self.inplace)
+
+
+def hard_sigmoid(x, inplace: bool = False):
+    if inplace:
+        return x.add_(3.).clamp_(0., 6.).div_(6.)  # clamp(x + 3, 0, 6) / 6 computed by overwriting x
+    else:
+        return F.relu6(x + 3.) / 6.  # same value, returned in a new tensor
+
+
+class HardSigmoid(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_sigmoid(x, self.inplace)
+
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7176b05e779787528a47f20d55d64d4a0f219360
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py
@@ -0,0 +1,79 @@
+""" Activations (jit)
+
+A collection of jit-scripted activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
+currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
+versions if they contain in-place ops.
+
+Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
+ 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
+
+
+@torch.jit.script
+def swish_jit(x, inplace: bool = False):
+ """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
+ and also as Swish (https://arxiv.org/abs/1710.05941).
+
+ TODO Rename to SiLU with addition to PyTorch
+ """
+ return x.mul(x.sigmoid())
+
+
+@torch.jit.script
+def mish_jit(x, _inplace: bool = False):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ """
+ return x.mul(F.softplus(x).tanh())
+
+
+class SwishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(SwishJit, self).__init__()
+
+ def forward(self, x):
+ return swish_jit(x)
+
+
+class MishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(MishJit, self).__init__()
+
+ def forward(self, x):
+ return mish_jit(x)
+
+
+@torch.jit.script
+def hard_sigmoid_jit(x, inplace: bool = False):
+ # return F.relu6(x + 3.) / 6.
+ return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
+
+
+class HardSigmoidJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoidJit, self).__init__()
+
+ def forward(self, x):
+ return hard_sigmoid_jit(x)
+
+
+@torch.jit.script
+def hard_swish_jit(x, inplace: bool = False):
+ # return x * (F.relu6(x + 3.) / 6)
+ return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
+
+
+class HardSwishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwishJit, self).__init__()
+
+ def forward(self, x):
+ return hard_swish_jit(x)
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_me.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_me.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91df5a50fdbe40bc386e2541a4fda743ad95e9a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_me.py
@@ -0,0 +1,174 @@
+""" Activations (memory-efficient w/ custom autograd)
+
+A collection of activations fn and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+These activations are not compatible with jit scripting or ONNX export of the model, please use either
+the JIT or basic versions of the activations.
+
+Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe',
+ 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe']
+
+
+@torch.jit.script
+def swish_jit_fwd(x):
+ return x.mul(torch.sigmoid(x))
+
+
+@torch.jit.script
+def swish_jit_bwd(x, grad_output):
+    x_sigmoid = torch.sigmoid(x)  # recomputed here from the saved input
+    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))  # d/dx[x * s(x)] = s(x) * (1 + x * (1 - s(x)))
+
+
+class SwishJitAutoFn(torch.autograd.Function):
+    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+    Inspired by conversation btw Jeremy Howard & Adam Paszke
+    https://twitter.com/jeremyphoward/status/1188251041835315200
+
+    Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
+    and also as Swish (https://arxiv.org/abs/1710.05941).
+
+    TODO Rename to SiLU with addition to PyTorch
+    """
+
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)  # only the input is saved for the backward pass
+        return swish_jit_fwd(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]  # the input stored by forward()
+        return swish_jit_bwd(x, grad_output)
+
+
+def swish_me(x, inplace=False):
+ return SwishJitAutoFn.apply(x)
+
+
+class SwishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(SwishMe, self).__init__()
+
+ def forward(self, x):
+ return SwishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def mish_jit_fwd(x):
+ return x.mul(torch.tanh(F.softplus(x)))
+
+
+@torch.jit.script
+def mish_jit_bwd(x, grad_output):
+    x_sigmoid = torch.sigmoid(x)  # derivative of softplus(x)
+    x_tanh_sp = F.softplus(x).tanh()
+    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))  # product + chain rule on x * tanh(softplus(x))
+
+
+class MishJitAutoFn(torch.autograd.Function):
+ """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ A memory efficient, jit scripted variant of Mish
+ """
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return mish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return mish_jit_bwd(x, grad_output)
+
+
+def mish_me(x, inplace=False):
+ return MishJitAutoFn.apply(x)
+
+
+class MishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(MishMe, self).__init__()
+
+ def forward(self, x):
+ return MishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+ return (x + 3).clamp(min=0, max=6).div(6.)
+
+
+@torch.jit.script
+def hard_sigmoid_jit_bwd(x, grad_output):
+    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.  # slope 1/6 where -3 <= x <= 3, zero elsewhere
+    return grad_output * m
+
+
+class HardSigmoidJitAutoFn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_sigmoid_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_sigmoid_jit_bwd(x, grad_output)
+
+
+def hard_sigmoid_me(x, inplace: bool = False):
+ return HardSigmoidJitAutoFn.apply(x)
+
+
+class HardSigmoidMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoidMe, self).__init__()
+
+ def forward(self, x):
+ return HardSigmoidJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_swish_jit_fwd(x):
+ return x * (x + 3).clamp(min=0, max=6).div(6.)
+
+
+@torch.jit.script
+def hard_swish_jit_bwd(x, grad_output):
+    m = torch.ones_like(x) * (x >= 3.)  # start with slope 1 where x >= 3 and 0 everywhere else
+    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)  # inside [-3, 3] the slope is x / 3 + 0.5
+    return grad_output * m
+
+
+class HardSwishJitAutoFn(torch.autograd.Function):
+ """A memory efficient, jit-scripted HardSwish activation"""
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_swish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_swish_jit_bwd(x, grad_output)
+
+
+def hard_swish_me(x, inplace=False):
+ return HardSwishJitAutoFn.apply(x)
+
+
+class HardSwishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwishMe, self).__init__()
+
+ def forward(self, x):
+ return HardSwishJitAutoFn.apply(x)
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/config.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..27d5307fd9ee0246f1e35f41520f17385d23f1dd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/config.py
@@ -0,0 +1,123 @@
+""" Global layer config state
+"""
+from typing import Any, Optional
+
+__all__ = [
+ 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs',
+ 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+]
+
+# Set to True if prefer to have layers with no jit optimization (includes activations)
+_NO_JIT = False
+
+# Set to True to prefer activation layers without jit optimization.
+# NOTE: not currently used. There is no difference between no_jit and no_activation_jit yet, because the
+# only layers obeying the jit flags so far are activations. This will change as more layers are updated and/or added.
+_NO_ACTIVATION_JIT = False
+
+# Set to True if exporting a model with Same padding via ONNX
+_EXPORTABLE = False
+
+# Set to True if wanting to use torch.jit.script on a model
+_SCRIPTABLE = False
+
+
+def is_no_jit():
+ return _NO_JIT
+
+
+class set_no_jit:
+ def __init__(self, mode: bool) -> None:
+ global _NO_JIT
+ self.prev = _NO_JIT
+ _NO_JIT = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _NO_JIT
+ _NO_JIT = self.prev
+ return False
+
+
+def is_exportable():
+ return _EXPORTABLE
+
+
+class set_exportable:
+ def __init__(self, mode: bool) -> None:
+ global _EXPORTABLE
+ self.prev = _EXPORTABLE
+ _EXPORTABLE = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _EXPORTABLE
+ _EXPORTABLE = self.prev
+ return False
+
+
+def is_scriptable():
+ return _SCRIPTABLE
+
+
+class set_scriptable:
+ def __init__(self, mode: bool) -> None:
+ global _SCRIPTABLE
+ self.prev = _SCRIPTABLE
+ _SCRIPTABLE = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _SCRIPTABLE
+ _SCRIPTABLE = self.prev
+ return False
+
+
+class set_layer_config:
+ """ Layer config context manager that allows setting all layer config flags at once.
+ If a flag arg is None, it will not change the current value.
+ """
+ def __init__(
+ self,
+ scriptable: Optional[bool] = None,
+ exportable: Optional[bool] = None,
+ no_jit: Optional[bool] = None,
+ no_activation_jit: Optional[bool] = None):
+ global _SCRIPTABLE
+ global _EXPORTABLE
+ global _NO_JIT
+ global _NO_ACTIVATION_JIT
+ self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
+ if scriptable is not None:
+ _SCRIPTABLE = scriptable
+ if exportable is not None:
+ _EXPORTABLE = exportable
+ if no_jit is not None:
+ _NO_JIT = no_jit
+ if no_activation_jit is not None:
+ _NO_ACTIVATION_JIT = no_activation_jit
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _SCRIPTABLE
+ global _EXPORTABLE
+ global _NO_JIT
+ global _NO_ACTIVATION_JIT
+ _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
+ return False
+
+
+def layer_config_kwargs(kwargs):
+ """ Consume config kwargs and return contextmgr obj """
+ return set_layer_config(
+ scriptable=kwargs.pop('scriptable', None),
+ exportable=kwargs.pop('exportable', None),
+ no_jit=kwargs.pop('no_jit', None))
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/conv2d_layers.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/conv2d_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8467460c4b36e54c83ce2dcd3ebe91d3432cad2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/conv2d_layers.py
@@ -0,0 +1,304 @@
+""" Conv2D w/ SAME padding, CondConv, MixedConv
+
+A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and
+MobileNetV3 models that maintain weight compatibility with original Tensorflow models.
+
+Copyright 2020 Ross Wightman
+"""
+import collections.abc
+import math
+from functools import partial
+from itertools import repeat
+from typing import Tuple, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .config import *
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+_single = _ntuple(1)
+_pair = _ntuple(2)
+_triple = _ntuple(3)
+_quadruple = _ntuple(4)
+
+
+def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
+
+
+def _get_padding(kernel_size, stride=1, dilation=1, **_):
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding
+
+
+def _calc_same_pad(i: int, k: int, s: int, d: int):
+ return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0)
+
+
+def _same_pad_arg(input_size, kernel_size, stride, dilation):
+ ih, iw = input_size
+ kh, kw = kernel_size
+ pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
+ pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
+ return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
+
+
+def _split_channels(num_chan, num_groups):
+ split = [num_chan // num_groups for _ in range(num_groups)]
+ split[0] += num_chan - sum(split)
+ return split
+
+
+def conv2d_same(
+ x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
+ ih, iw = x.size()[-2:]
+ kh, kw = weight.size()[-2:]
+ pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0])
+ pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1])
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+ return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
+
+
+class Conv2dSame(nn.Conv2d):
+ """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True):
+ super(Conv2dSame, self).__init__(
+ in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+
+ def forward(self, x):
+ return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class Conv2dSameExport(nn.Conv2d):
+ """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+
+ NOTE: This does not currently work with torch.jit.script
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True):
+ super(Conv2dSameExport, self).__init__(
+ in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+ self.pad = None
+ self.pad_input_size = (0, 0)
+
+ def forward(self, x):
+ input_size = x.size()[-2:]
+ if self.pad is None:
+ pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation)
+ self.pad = nn.ZeroPad2d(pad_arg)
+ self.pad_input_size = input_size
+
+ if self.pad is not None:
+ x = self.pad(x)
+ return F.conv2d(
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+def get_padding_value(padding, kernel_size, **kwargs):
+ dynamic = False
+ if isinstance(padding, str):
+ # for any string padding, the padding will be calculated for you, one of three ways
+ padding = padding.lower()
+ if padding == 'same':
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+ if _is_static_pad(kernel_size, **kwargs):
+ # static case, no extra overhead
+ padding = _get_padding(kernel_size, **kwargs)
+ else:
+ # dynamic padding
+ padding = 0
+ dynamic = True
+ elif padding == 'valid':
+ # 'VALID' padding, same as padding=0
+ padding = 0
+ else:
+ # Default to PyTorch style 'same'-ish symmetric padding
+ padding = _get_padding(kernel_size, **kwargs)
+ return padding, dynamic
+
+
+def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
+ padding = kwargs.pop('padding', '')
+ kwargs.setdefault('bias', False)
+ padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
+ if is_dynamic:
+ if is_exportable():
+ assert not is_scriptable()
+ return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)
+ else:
+ return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+ else:
+ return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
+
+
+class MixedConv2d(nn.ModuleDict):
+ """ Mixed Grouped Convolution
+ Based on MDConv and GroupedConv in MixNet impl:
+ https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding='', dilation=1, depthwise=False, **kwargs):
+ super(MixedConv2d, self).__init__()
+
+ kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
+ num_groups = len(kernel_size)
+ in_splits = _split_channels(in_channels, num_groups)
+ out_splits = _split_channels(out_channels, num_groups)
+ self.in_channels = sum(in_splits)
+ self.out_channels = sum(out_splits)
+ for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
+ conv_groups = out_ch if depthwise else 1
+ self.add_module(
+ str(idx),
+ create_conv2d_pad(
+ in_ch, out_ch, k, stride=stride,
+ padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
+ )
+ self.splits = in_splits
+
+ def forward(self, x):
+ x_split = torch.split(x, self.splits, 1)
+ x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())]
+ x = torch.cat(x_out, 1)
+ return x
+
+
+def get_condconv_initializer(initializer, num_experts, expert_shape):
+ def condconv_initializer(weight):
+ """CondConv initializer function."""
+ num_params = np.prod(expert_shape)
+ if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
+ weight.shape[1] != num_params):
+ raise (ValueError(
+ 'CondConv variables must have shape [num_experts, num_params]'))
+ for i in range(num_experts):
+ initializer(weight[i].view(expert_shape))
+ return condconv_initializer
+
+
+class CondConv2d(nn.Module):
+ """ Conditional Convolution
+ Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
+
+ Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
+ https://github.com/pytorch/pytorch/issues/17983
+ """
+ __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
+ super(CondConv2d, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ padding_val, is_padding_dynamic = get_padding_value(
+ padding, kernel_size, stride=stride, dilation=dilation)
+ self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
+ self.padding = _pair(padding_val)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.num_experts = num_experts
+
+ self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
+ weight_num_param = 1
+ for wd in self.weight_shape:
+ weight_num_param *= wd
+ self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
+
+ if bias:
+ self.bias_shape = (self.out_channels,)
+ self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init_weight = get_condconv_initializer(
+ partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
+ init_weight(self.weight)
+ if self.bias is not None:
+ fan_in = np.prod(self.weight_shape[1:])
+ bound = 1 / math.sqrt(fan_in)
+ init_bias = get_condconv_initializer(
+ partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
+ init_bias(self.bias)
+
+ def forward(self, x, routing_weights):
+ B, C, H, W = x.shape
+ weight = torch.matmul(routing_weights, self.weight)
+ new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
+ weight = weight.view(new_weight_shape)
+ bias = None
+ if self.bias is not None:
+ bias = torch.matmul(routing_weights, self.bias)
+ bias = bias.view(B * self.out_channels)
+ # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
+ x = x.view(1, B * C, H, W)
+ if self.dynamic_padding:
+ out = conv2d_same(
+ x, weight, bias, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups * B)
+ else:
+ out = F.conv2d(
+ x, weight, bias, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups * B)
+ out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
+
+ # Literal port (from TF definition)
+ # x = torch.split(x, 1, 0)
+ # weight = torch.split(weight, 1, 0)
+ # if self.bias is not None:
+ # bias = torch.matmul(routing_weights, self.bias)
+ # bias = torch.split(bias, 1, 0)
+ # else:
+ # bias = [None] * B
+ # out = []
+ # for xi, wi, bi in zip(x, weight, bias):
+ # wi = wi.view(*self.weight_shape)
+ # if bi is not None:
+ # bi = bi.view(*self.bias_shape)
+ # out.append(self.conv_fn(
+ # xi, wi, bi, stride=self.stride, padding=self.padding,
+ # dilation=self.dilation, groups=self.groups))
+ # out = torch.cat(out, 0)
+ return out
+
+
+def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
+ assert 'groups' not in kwargs # only use 'depthwise' bool arg
+ if isinstance(kernel_size, list):
+ assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
+ # We're going to use only lists for defining the MixedConv2d kernel groups,
+ # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
+ m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
+ else:
+ depthwise = kwargs.pop('depthwise', False)
+ groups = out_chs if depthwise else 1
+ if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
+ m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+ else:
+ m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
+ return m
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/efficientnet_builder.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/efficientnet_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..95dd63d400e70d70664c5a433a2772363f865e61
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/efficientnet_builder.py
@@ -0,0 +1,683 @@
+""" EfficientNet / MobileNetV3 Blocks and Builder
+
+Copyright 2020 Ross Wightman
+"""
+import re
+from copy import deepcopy
+
+from .conv2d_layers import *
+from geffnet.activations import *
+
+__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
+ 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
+ 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def',
+ 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'
+]
+
+# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
+# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+# NOTE: momentum varies btw .99 and .9997 depending on source
+# .99 in official TF TPU impl
+# .9997 (/w .999 in search space) for paper
+#
+# PyTorch defaults are momentum = .1, eps = 1e-5
+#
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+
+def get_bn_args_tf():
+ return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+ bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+ bn_momentum = kwargs.pop('bn_momentum', None)
+ if bn_momentum is not None:
+ bn_args['momentum'] = bn_momentum
+ bn_eps = kwargs.pop('bn_eps', None)
+ if bn_eps is not None:
+ bn_args['eps'] = bn_eps
+ return bn_args
+
+
+_SE_ARGS_DEFAULT = dict(
+ gate_fn=sigmoid,
+ act_layer=None, # None == use containing block's activation layer
+ reduce_mid=False,
+ divisor=1)
+
+
+def resolve_se_args(kwargs, in_chs, act_layer=None):
+ se_kwargs = kwargs.copy() if kwargs is not None else {}
+ # fill in args that aren't specified with the defaults
+ for k, v in _SE_ARGS_DEFAULT.items():
+ se_kwargs.setdefault(k, v)
+ # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
+ if not se_kwargs.pop('reduce_mid'):
+ se_kwargs['reduced_base_chs'] = in_chs
+ # act_layer override, if it remains None, the containing block's act_layer will be used
+ if se_kwargs['act_layer'] is None:
+ assert act_layer is not None
+ se_kwargs['act_layer'] = act_layer
+ return se_kwargs
+
+
+def resolve_act_layer(kwargs, default='relu'):
+ act_layer = kwargs.pop('act_layer', default)
+ if isinstance(act_layer, str):
+ act_layer = get_act_layer(act_layer)
+ return act_layer
+
+
+def make_divisible(v: int, divisor: int = 8, min_value: int = None):
+ min_value = min_value or divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v: # ensure round down does not go down by more than 10%.
+ new_v += divisor
+ return new_v
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
+ """Round number of filters based on depth multiplier."""
+ if not multiplier:
+ return channels
+ channels *= multiplier
+ return make_divisible(channels, divisor, channel_min)
+
+
+def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.):
+ """Apply drop connect."""
+ if not training:
+ return inputs
+
+ keep_prob = 1 - drop_connect_rate
+ random_tensor = keep_prob + torch.rand(
+ (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device)
+ random_tensor.floor_() # binarize
+ output = inputs.div(keep_prob) * random_tensor
+ return output
+
+
+class SqueezeExcite(nn.Module):
+
+ def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1):
+ super(SqueezeExcite, self).__init__()
+ reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
+ self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+ self.act1 = act_layer(inplace=True)
+ self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+ self.gate_fn = gate_fn
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ x_se = self.conv_reduce(x_se)
+ x_se = self.act1(x_se)
+ x_se = self.conv_expand(x_se)
+ x = x * self.gate_fn(x_se)
+ return x
+
+
+class ConvBnAct(nn.Module):
+ def __init__(self, in_chs, out_chs, kernel_size,
+ stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+ super(ConvBnAct, self).__init__()
+ assert stride in [1, 2]
+ norm_kwargs = norm_kwargs or {}
+ self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
+ self.bn1 = norm_layer(out_chs, **norm_kwargs)
+ self.act1 = act_layer(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ return x
+
+
+class DepthwiseSeparableConv(nn.Module):
+ """ DepthwiseSeparable block
+ Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
+ factor of 1.0. This is an alternative to having a IR with optional first pw conv.
+ """
+ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
+ stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
+ pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
+ norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
+ super(DepthwiseSeparableConv, self).__init__()
+ assert stride in [1, 2]
+ norm_kwargs = norm_kwargs or {}
+ self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
+ self.drop_connect_rate = drop_connect_rate
+
+ self.conv_dw = select_conv2d(
+ in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
+ self.bn1 = norm_layer(in_chs, **norm_kwargs)
+ self.act1 = act_layer(inplace=True)
+
+ # Squeeze-and-excitation
+ if se_ratio is not None and se_ratio > 0.:
+ se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
+ self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
+ else:
+ self.se = nn.Identity()
+
+ self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
+ self.bn2 = norm_layer(out_chs, **norm_kwargs)
+ self.act2 = act_layer(inplace=True) if pw_act else nn.Identity()
+
+ def forward(self, x):
+ residual = x
+
+ x = self.conv_dw(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.se(x)
+
+ x = self.conv_pw(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ if self.has_residual:
+ if self.drop_connect_rate > 0.:
+ x = drop_connect(x, self.training, self.drop_connect_rate)
+ x += residual
+ return x
+
+
+class InvertedResidual(nn.Module):
+ """ Inverted residual block w/ optional SE"""
+
+ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
+ stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
+ exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
+ se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+ conv_kwargs=None, drop_connect_rate=0.):
+ super(InvertedResidual, self).__init__()
+ norm_kwargs = norm_kwargs or {}
+ conv_kwargs = conv_kwargs or {}
+ mid_chs: int = make_divisible(in_chs * exp_ratio)
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+ self.drop_connect_rate = drop_connect_rate
+
+ # Point-wise expansion
+ self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
+ self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+ self.act1 = act_layer(inplace=True)
+
+ # Depth-wise convolution
+ self.conv_dw = select_conv2d(
+ mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs)
+ self.bn2 = norm_layer(mid_chs, **norm_kwargs)
+ self.act2 = act_layer(inplace=True)
+
+ # Squeeze-and-excitation
+ if se_ratio is not None and se_ratio > 0.:
+ se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
+ self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+ else:
+ self.se = nn.Identity() # for jit.script compat
+
+ # Point-wise linear projection
+ self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
+ self.bn3 = norm_layer(out_chs, **norm_kwargs)
+
+ def forward(self, x):
+ residual = x
+
+ # Point-wise expansion
+ x = self.conv_pw(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ # Depth-wise convolution
+ x = self.conv_dw(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ # Squeeze-and-excitation
+ x = self.se(x)
+
+ # Point-wise linear projection
+ x = self.conv_pwl(x)
+ x = self.bn3(x)
+
+ if self.has_residual:
+ if self.drop_connect_rate > 0.:
+ x = drop_connect(x, self.training, self.drop_connect_rate)
+ x += residual
+ return x
+
+
+class CondConvResidual(InvertedResidual):
+ """ Inverted residual block w/ CondConv routing"""
+
+ def __init__(self, in_chs, out_chs, dw_kernel_size=3,
+ stride=1, pad_type='', act_layer=nn.ReLU, noskip=False,
+ exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
+ se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+ num_experts=0, drop_connect_rate=0.):
+
+ self.num_experts = num_experts
+ conv_kwargs = dict(num_experts=self.num_experts)
+
+ super(CondConvResidual, self).__init__(
+ in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type,
+ act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
+ pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
+ norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
+ drop_connect_rate=drop_connect_rate)
+
+ self.routing_fn = nn.Linear(in_chs, self.num_experts)
+
+ def forward(self, x):
+ residual = x
+
+ # CondConv routing
+ pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
+ routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
+
+ # Point-wise expansion
+ x = self.conv_pw(x, routing_weights)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ # Depth-wise convolution
+ x = self.conv_dw(x, routing_weights)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ # Squeeze-and-excitation
+ x = self.se(x)
+
+ # Point-wise linear projection
+ x = self.conv_pwl(x, routing_weights)
+ x = self.bn3(x)
+
+ if self.has_residual:
+ if self.drop_connect_rate > 0.:
+ x = drop_connect(x, self.training, self.drop_connect_rate)
+ x += residual
+ return x
+
+
+class EdgeResidual(nn.Module):
+ """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride"""
+
+ def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
+ stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
+ se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
+ super(EdgeResidual, self).__init__()
+ norm_kwargs = norm_kwargs or {}
+ mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio)
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+ self.drop_connect_rate = drop_connect_rate
+
+ # Expansion convolution
+ self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
+ self.bn1 = norm_layer(mid_chs, **norm_kwargs)
+ self.act1 = act_layer(inplace=True)
+
+ # Squeeze-and-excitation
+ if se_ratio is not None and se_ratio > 0.:
+ se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
+ self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+ else:
+ self.se = nn.Identity()
+
+ # Point-wise linear projection
+ self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
+ self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs)
+
+ def forward(self, x):
+ residual = x
+
+ # Expansion convolution
+ x = self.conv_exp(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ # Squeeze-and-excitation
+ x = self.se(x)
+
+ # Point-wise linear projection
+ x = self.conv_pwl(x)
+ x = self.bn2(x)
+
+ if self.has_residual:
+ if self.drop_connect_rate > 0.:
+ x = drop_connect(x, self.training, self.drop_connect_rate)
+ x += residual
+
+ return x
+
+
+class EfficientNetBuilder:
+ """ Build Trunk Blocks for Efficient/Mobile Networks
+
+ This ended up being somewhat of a cross between
+ https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
+ and
+ https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
+
+ """
+
+ def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
+ pad_type='', act_layer=None, se_kwargs=None,
+ norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.):
+ self.channel_multiplier = channel_multiplier
+ self.channel_divisor = channel_divisor
+ self.channel_min = channel_min
+ self.pad_type = pad_type
+ self.act_layer = act_layer
+ self.se_kwargs = se_kwargs
+ self.norm_layer = norm_layer
+ self.norm_kwargs = norm_kwargs
+ self.drop_connect_rate = drop_connect_rate
+
+ # updated during build
+ self.in_chs = None
+ self.block_idx = 0
+ self.block_count = 0
+
+ def _round_channels(self, chs):
+ return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
+
+ def _make_block(self, ba):
+ bt = ba.pop('block_type')
+ ba['in_chs'] = self.in_chs
+ ba['out_chs'] = self._round_channels(ba['out_chs'])
+ if 'fake_in_chs' in ba and ba['fake_in_chs']:
+ # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU
+ ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
+ ba['norm_layer'] = self.norm_layer
+ ba['norm_kwargs'] = self.norm_kwargs
+ ba['pad_type'] = self.pad_type
+ # block act fn overrides the model default
+ ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
+ assert ba['act_layer'] is not None
+ if bt == 'ir':
+ ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
+ ba['se_kwargs'] = self.se_kwargs
+ if ba.get('num_experts', 0) > 0:
+ block = CondConvResidual(**ba)
+ else:
+ block = InvertedResidual(**ba)
+ elif bt == 'ds' or bt == 'dsa':
+ ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
+ ba['se_kwargs'] = self.se_kwargs
+ block = DepthwiseSeparableConv(**ba)
+ elif bt == 'er':
+ ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
+ ba['se_kwargs'] = self.se_kwargs
+ block = EdgeResidual(**ba)
+ elif bt == 'cn':
+ block = ConvBnAct(**ba)
+ else:
+ assert False, 'Uknkown block type (%s) while building model.' % bt
+ self.in_chs = ba['out_chs'] # update in_chs for arg of next block
+ return block
+
+ def _make_stack(self, stack_args):
+ blocks = []
+ # each stack (stage) contains a list of block arguments
+ for i, ba in enumerate(stack_args):
+ if i >= 1:
+ # only the first block in any stack can have a stride > 1
+ ba['stride'] = 1
+ block = self._make_block(ba)
+ blocks.append(block)
+ self.block_idx += 1 # incr global idx (across all stacks)
+ return nn.Sequential(*blocks)
+
+ def __call__(self, in_chs, block_args):
+ """ Build the blocks
+ Args:
+ in_chs: Number of input-channels passed to first block
+ block_args: A list of lists, outer list defines stages, inner
+ list contains strings defining block configuration(s)
+ Return:
+ List of block stacks (each stack wrapped in nn.Sequential)
+ """
+ self.in_chs = in_chs
+ self.block_count = sum([len(x) for x in block_args])
+ self.block_idx = 0
+ blocks = []
+ # outer list of block_args defines the stacks ('stages' by some conventions)
+ for stack_idx, stack in enumerate(block_args):
+ assert isinstance(stack, list)
+ stack = self._make_stack(stack)
+ blocks.append(stack)
+ return blocks
+
+
+def _parse_ksize(ss):
+ if ss.isdigit():
+ return int(ss)
+ else:
+ return [int(k) for k in ss.split('.')]
+
+
+def _decode_block_str(block_str):
+    """ Decode block definition string
+
+    Gets a block arg dict and its repeat count through a string notation of arguments.
+    E.g. ir_r2_k3_s2_e1_c16_se0.25_noskip
+
+    All args can exist in any order with the exception of the leading string which
+    is assumed to indicate the block type.
+
+    leading string - block type (
+      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct,
+      er = block type that takes the exp_kernel_size / fake_in_chs args built below),
+    r - number of repeat blocks, k - kernel size, s - strides (1-9),
+    e - expansion ratio, c - output channels, se - squeeze/excitation ratio,
+    a - expansion kernel size, p - pointwise kernel size (both default to 1),
+    fc - fake_in_chs ('er' only), cc - num_experts ('ir' only), noskip - sets the noskip flag,
+    n - activation fn ('re', 'r6', 'hs', or 'sw'); any other code is silently ignored
+    Args:
+        block_str: a string representation of block arguments.
+    Returns:
+        A (block_args dict, num_repeat) tuple
+    Raises:
+        KeyError / ValueError: if a required option is missing / is not a valid number
+        AssertionError: if block_str is not a str or the block type is unknown
+    """
+    assert isinstance(block_str, str)
+    ops = block_str.split('_')
+    block_type = ops[0]  # take the block type off the front
+    ops = ops[1:]
+    options = {}
+    noskip = False
+    for op in ops:
+        # string options being checked on individual basis, combine if they grow
+        if op == 'noskip':  # must be tested before the generic 'n...' activation branch
+            noskip = True
+        elif op.startswith('n'):
+            # activation fn
+            key = op[0]
+            v = op[1:]
+            if v == 're':
+                value = get_act_layer('relu')
+            elif v == 'r6':
+                value = get_act_layer('relu6')
+            elif v == 'hs':
+                value = get_act_layer('hard_swish')
+            elif v == 'sw':
+                value = get_act_layer('swish')
+            else:
+                continue  # unrecognized activation code: the option is dropped
+            options[key] = value
+        else:
+            # all numeric options
+            splits = re.split(r'(\d.*)', op)  # e.g. 'se0.25' -> ['se', '0.25', '']
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value  # kept as a string; converted per block type below
+
+    # if act_layer is None, the model default (passed to model init) will be used
+    act_layer = options['n'] if 'n' in options else None
+    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+
+    num_repeat = int(options['r'])
+    # each type of block has different valid arguments, fill accordingly
+    if block_type == 'ir':
+        block_args = dict(
+            block_type=block_type,
+            dw_kernel_size=_parse_ksize(options['k']),
+            exp_kernel_size=exp_kernel_size,
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_layer=act_layer,
+            noskip=noskip,
+        )
+        if 'cc' in options:
+            block_args['num_experts'] = int(options['cc'])
+    elif block_type == 'ds' or block_type == 'dsa':
+        block_args = dict(
+            block_type=block_type,
+            dw_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_layer=act_layer,
+            pw_act=block_type == 'dsa',
+            noskip=block_type == 'dsa' or noskip,  # 'dsa' always sets noskip
+        )
+    elif block_type == 'er':
+        block_args = dict(
+            block_type=block_type,
+            exp_kernel_size=_parse_ksize(options['k']),  # 'k' feeds exp_kernel_size here; 'a' is not read
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            fake_in_chs=fake_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_layer=act_layer,
+            noskip=noskip,
+        )
+    elif block_type == 'cn':
+        block_args = dict(
+            block_type=block_type,
+            kernel_size=int(options['k']),  # plain int only: dot-separated kernel lists are not parsed for 'cn'
+            out_chs=int(options['c']),
+            stride=int(options['s']),
+            act_layer=act_layer,
+        )
+    else:
+        assert False, 'Unknown block type (%s)' % block_type
+
+    return block_args, num_repeat
+
+
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """ Per-stage depth scaling
+    Scales the block repeats in each stage, staying compatible with the EfficientNet scaling method
+    while allowing sensible scaling for models with multiple block arg definitions in each stage.
+    `stack_args` / `repeats` are parallel lists; returns one deep-copied block arg dict per scaled repeat.
+    """
+
+    # We scale the total repeat count for each stage, there may be multiple
+    # block arg defs per stage so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer. This is a good choice when stage definitions
+        # include single repeat stages that we'd prefer to keep that way as long as possible
+        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+    else:
+        # The default for EfficientNet truncates repeats to int via 'ceil'.
+        # Any multiplier > 1.0 will result in an increased depth for every stage.
+        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))  # any non-'round' value lands here
+
+    # Proportionally distribute repeat count scaling to each block definition in the stage.
+    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
+    # The first block makes less sense to repeat in most of the arch definitions.
+    repeats_scaled = []
+    for r in repeats[::-1]:
+        rs = max(1, round((r / num_repeat * num_repeat_scaled)))  # every block def keeps at least one repeat
+        repeats_scaled.append(rs)
+        num_repeat -= r  # original repeats still to be allocated
+        num_repeat_scaled -= rs  # scaled repeats still to be handed out
+    repeats_scaled = repeats_scaled[::-1]  # back to the order of stack_args
+
+    # Apply the calculated scaling to each block arg in the stage
+    sa_scaled = []
+    for ba, rep in zip(stack_args, repeats_scaled):
+        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])  # independent dict per repeat
+    return sa_scaled
+
+
+def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
+    """Decode per-stage lists of block strings into per-stage lists of depth-scaled block arg dicts."""
+    decoded_stages = []
+    for stage_idx, stage_strings in enumerate(arch_def):
+        assert isinstance(stage_strings, list)
+        stage_blocks, stage_repeats = [], []
+        for block_str in stage_strings:
+            assert isinstance(block_str, str)
+            block, num_rep = _decode_block_str(block_str)
+            if block.get('num_experts', 0) > 0 and experts_multiplier > 1:
+                block['num_experts'] *= experts_multiplier
+            stage_blocks.append(block)
+            stage_repeats.append(num_rep)
+        # With fix_first_last, the first and last stages keep their defined depth (multiplier 1.0).
+        pinned = fix_first_last and (stage_idx == 0 or stage_idx == len(arch_def) - 1)
+        stage_mult = 1.0 if pinned else depth_multiplier
+        decoded_stages.append(_scale_stage_depth(stage_blocks, stage_repeats, stage_mult, depth_trunc))
+    return decoded_stages
+
+
+def initialize_weight_goog(m, n='', fix_group_fanout=True):
+    # weight init as per Tensorflow Official impl; `m` is the module, `n` its name from named_modules()
+    # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+    if isinstance(m, CondConv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        if fix_group_fanout:
+            fan_out //= m.groups
+        init_weight_fn = get_condconv_initializer(
+            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
+        init_weight_fn(m.weight)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, nn.Conv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        if fix_group_fanout:
+            fan_out //= m.groups
+        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif isinstance(m, nn.BatchNorm2d):
+        m.weight.data.fill_(1.0)
+        m.bias.data.zero_()
+    elif isinstance(m, nn.Linear):
+        fan_out = m.weight.size(0)  # fan-out
+        # fan-in only contributes for layers whose module name contains 'routing_fn'
+        fan_in = m.weight.size(1) if 'routing_fn' in n else 0
+        init_range = 1.0 / math.sqrt(fan_in + fan_out)
+        m.weight.data.uniform_(-init_range, init_range)
+        if m.bias is not None:  # bias-free Linear layers have bias=None; guard like the conv branches
+            m.bias.data.zero_()
+
+
+def initialize_weight_default(m, n=''):  # `n` (module name) is accepted but not read in this function
+    if isinstance(m, CondConv2d):
+        init_fn = get_condconv_initializer(partial(
+            nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
+        init_fn(m.weight)  # same Kaiming-normal settings as the plain conv branch, via the condconv helper
+    elif isinstance(m, nn.Conv2d):
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  # conv bias is left untouched
+    elif isinstance(m, nn.BatchNorm2d):
+        m.weight.data.fill_(1.0)
+        m.bias.data.zero_()
+    elif isinstance(m, nn.Linear):
+        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')  # linear bias is left untouched
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/gen_efficientnet.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/gen_efficientnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd170d4cc5bed6ca82b61539902b470d3320c691
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/gen_efficientnet.py
@@ -0,0 +1,1450 @@
+""" Generic Efficient Networks
+
+A generic MobileNet class with building blocks to support a variety of models:
+
+* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports)
+ - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
+ - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
+ - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
+ - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
+
+* EfficientNet-Lite
+
+* MixNet (Small, Medium, and Large)
+ - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
+
+* MNasNet B1, A1 (SE), Small
+ - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626
+
+* FBNet-C
+ - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443
+
+* Single-Path NAS Pixel1
+ - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877
+
+* And likely more...
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .config import layer_config_kwargs, is_scriptable
+from .conv2d_layers import select_conv2d
+from .helpers import load_pretrained
+from .efficientnet_builder import *
+
+__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
+           'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',
+           'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d',
+           'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2',  'efficientnet_b3',
+           'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8',
+           'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el',
+           'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e',
+           'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4',
+           'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3',
+           'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8',
+           'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap',
+           'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap',
+           'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns',
+           'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns',
+           'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
+           'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el',
+           'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
+           'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3',
+           'tf_efficientnet_lite4',
+           'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l']  # star-import surface
+
+
+model_urls = {  # variant name -> pretrained checkpoint URL; None where no URL is listed for the variant
+    'mnasnet_050': None,
+    'mnasnet_075': None,
+    'mnasnet_100':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
+    'mnasnet_140': None,
+    'mnasnet_small': None,
+
+    'semnasnet_050': None,
+    'semnasnet_075': None,
+    'semnasnet_100':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
+    'semnasnet_140': None,
+
+    'mobilenetv2_100':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth',
+    'mobilenetv2_110d':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth',
+    'mobilenetv2_120d':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth',
+    'mobilenetv2_140':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth',
+
+    'fbnetc_100':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
+    'spnasnet_100':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
+
+    'efficientnet_b0':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth',
+    'efficientnet_b1':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
+    'efficientnet_b2':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth',
+    'efficientnet_b3':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth',
+    'efficientnet_b4': None,
+    'efficientnet_b5': None,
+    'efficientnet_b6': None,
+    'efficientnet_b7': None,
+    'efficientnet_b8': None,
+    'efficientnet_l2': None,
+
+    'efficientnet_es':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
+    'efficientnet_em': None,
+    'efficientnet_el': None,
+
+    'efficientnet_cc_b0_4e': None,
+    'efficientnet_cc_b0_8e': None,
+    'efficientnet_cc_b1_8e': None,
+
+    'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth',
+    'efficientnet_lite1': None,
+    'efficientnet_lite2': None,
+    'efficientnet_lite3': None,
+    'efficientnet_lite4': None,
+
+    'tf_efficientnet_b0':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
+    'tf_efficientnet_b1':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
+    'tf_efficientnet_b2':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
+    'tf_efficientnet_b3':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
+    'tf_efficientnet_b4':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
+    'tf_efficientnet_b5':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
+    'tf_efficientnet_b6':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
+    'tf_efficientnet_b7':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
+    'tf_efficientnet_b8':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
+
+    'tf_efficientnet_b0_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
+    'tf_efficientnet_b1_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth',
+    'tf_efficientnet_b2_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth',
+    'tf_efficientnet_b3_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth',
+    'tf_efficientnet_b4_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth',
+    'tf_efficientnet_b5_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth',
+    'tf_efficientnet_b6_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth',
+    'tf_efficientnet_b7_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth',
+    'tf_efficientnet_b8_ap':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
+
+    'tf_efficientnet_b0_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
+    'tf_efficientnet_b1_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
+    'tf_efficientnet_b2_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
+    'tf_efficientnet_b3_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
+    'tf_efficientnet_b4_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
+    'tf_efficientnet_b5_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
+    'tf_efficientnet_b6_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
+    'tf_efficientnet_b7_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
+    'tf_efficientnet_l2_ns_475':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
+    'tf_efficientnet_l2_ns':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
+
+    'tf_efficientnet_es':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
+    'tf_efficientnet_em':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
+    'tf_efficientnet_el':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
+
+    'tf_efficientnet_cc_b0_4e':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
+    'tf_efficientnet_cc_b0_8e':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth',
+    'tf_efficientnet_cc_b1_8e':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
+
+    'tf_efficientnet_lite0':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
+    'tf_efficientnet_lite1':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
+    'tf_efficientnet_lite2':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
+    'tf_efficientnet_lite3':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
+    'tf_efficientnet_lite4':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
+
+    'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth',
+    'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth',
+    'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth',
+    'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth',
+
+    'tf_mixnet_s':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth',
+    'tf_mixnet_m':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth',
+    'tf_mixnet_l':
+        'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth',
+}
+
+
+class GenEfficientNet(nn.Module):
+    """ Generic EfficientNets
+
+    An implementation of mobile optimized networks that covers:
+      * EfficientNet (B0-B8, L2, CondConv, EdgeTPU)
+      * MixNet (Small, Medium, and Large, XL)
+      * MNASNet A1, B1, and small
+      * FBNet C
+      * Single-Path NAS Pixel1
+    """
+
+    def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False,
+                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
+                 pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
+                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+                 weight_init='goog'):
+        """Build stem, block stack, head conv and classifier; `block_args` is a decoded arch (see decode_arch_def)."""
+        super(GenEfficientNet, self).__init__()
+        self.drop_rate = drop_rate
+        norm_kwargs = norm_kwargs or {}  # the None default cannot be **-unpacked into norm_layer below
+
+        if not fix_stem:  # fix_stem keeps stem_size as given instead of scaling it with channel_multiplier
+            stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
+        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+        self.bn1 = norm_layer(stem_size, **norm_kwargs)
+        self.act1 = act_layer(inplace=True)
+        in_chs = stem_size
+
+        builder = EfficientNetBuilder(
+            channel_multiplier, channel_divisor, channel_min,
+            pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate)
+        self.blocks = nn.Sequential(*builder(in_chs, block_args))
+        in_chs = builder.in_chs  # channel count after the last block, as reported by the builder
+
+        self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type)
+        self.bn2 = norm_layer(num_features, **norm_kwargs)
+        self.act2 = act_layer(inplace=True)
+        self.global_pool = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Linear(num_features, num_classes)
+
+        init_fn = initialize_weight_goog if weight_init == 'goog' else initialize_weight_default
+        for n, m in self.named_modules():
+            init_fn(m, n)
+
+    def features(self, x):  # stem -> blocks -> head conv/bn/act; the result is not pooled
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+        x = self.blocks(x)
+        x = self.conv_head(x)
+        x = self.bn2(x)
+        x = self.act2(x)
+        return x
+
+    def as_sequential(self):  # same layers as forward(), flattened into one nn.Sequential
+        layers = [self.conv_stem, self.bn1, self.act1]
+        layers.extend(self.blocks)
+        layers.extend([
+            self.conv_head, self.bn2, self.act2,
+            self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
+        return nn.Sequential(*layers)
+
+    def forward(self, x):  # features -> global average pool -> flatten -> optional dropout -> classifier
+        x = self.features(x)
+        x = self.global_pool(x)
+        x = x.flatten(1)
+        if self.drop_rate > 0.:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        return self.classifier(x)
+
+
+def _create_model(model_kwargs, variant, pretrained=False):
+    # 'as_sequential' is popped first so it never reaches the GenEfficientNet constructor;
+    # `variant` is only used as the model_urls key when pretrained weights are requested.
+    flatten = model_kwargs.pop('as_sequential', False)
+    net = GenEfficientNet(**model_kwargs)
+    if pretrained:
+        load_pretrained(net, model_urls[variant])
+    return net.as_sequential() if flatten else net
+
+
+def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates a mnasnet-a1 model.
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+    Paper: https://arxiv.org/pdf/1807.11626.pdf.
+    Args:
+      variant: key into `model_urls`, looked up by `_create_model` when `pretrained` is set.
+      channel_multiplier: multiplier to number of channels per layer.
+    """
+    arch_def = [
+        # stage 0, 112x112 in
+        ['ds_r1_k3_s1_e1_c16_noskip'],
+        # stage 1, 112x112 in
+        ['ir_r2_k3_s2_e6_c24'],
+        # stage 2, 56x56 in
+        ['ir_r3_k5_s2_e3_c40_se0.25'],
+        # stage 3, 28x28 in
+        ['ir_r4_k3_s2_e6_c80'],
+        # stage 4, 14x14in
+        ['ir_r2_k3_s1_e6_c112_se0.25'],
+        # stage 5, 14x14in
+        ['ir_r3_k5_s2_e6_c160_se0.25'],
+        # stage 6, 7x7 in
+        ['ir_r1_k3_s1_e6_c320'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def),
+            stem_size=32,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'relu'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates a mnasnet-b1 model.
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+    Paper: https://arxiv.org/pdf/1807.11626.pdf.
+    Args:
+      variant: key into `model_urls`, looked up by `_create_model` when `pretrained` is set.
+      channel_multiplier: multiplier to number of channels per layer.
+    """
+    arch_def = [
+        # stage 0, 112x112 in
+        ['ds_r1_k3_s1_c16_noskip'],
+        # stage 1, 112x112 in
+        ['ir_r3_k3_s2_e3_c24'],
+        # stage 2, 56x56 in
+        ['ir_r3_k5_s2_e3_c40'],
+        # stage 3, 28x28 in
+        ['ir_r3_k5_s2_e6_c80'],
+        # stage 4, 14x14in
+        ['ir_r2_k3_s1_e6_c96'],
+        # stage 5, 14x14in
+        ['ir_r4_k5_s2_e6_c192'],
+        # stage 6, 7x7 in
+        ['ir_r1_k3_s1_e6_c320_noskip']
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def),
+            stem_size=32,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'relu'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates a mnasnet-small model.
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+    Paper: https://arxiv.org/pdf/1807.11626.pdf.
+    Args:
+      variant: key into `model_urls`, looked up by `_create_model` when `pretrained` is set.
+      channel_multiplier: multiplier to number of channels per layer.
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_c8'],
+        ['ir_r1_k3_s2_e3_c16'],
+        ['ir_r2_k3_s2_e6_c16'],
+        ['ir_r4_k5_s2_e6_c32_se0.25'],
+        ['ir_r3_k3_s1_e6_c32_se0.25'],
+        ['ir_r3_k5_s2_e6_c88_se0.25'],
+        ['ir_r1_k3_s1_e6_c144']
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def),
+            stem_size=8,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'relu'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_mobilenet_v2(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
+    """ Generate MobileNet-V2 network
+    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
+    Paper: https://arxiv.org/abs/1801.04381
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_c16'],
+        ['ir_r2_k3_s2_e6_c24'],
+        ['ir_r3_k3_s2_e6_c32'],
+        ['ir_r4_k3_s2_e6_c64'],
+        ['ir_r3_k3_s1_e6_c96'],
+        ['ir_r3_k3_s2_e6_c160'],
+        ['ir_r1_k3_s1_e6_c320'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            # fix_stem_head pins the first/last stage depth, the head width (1280) and the stem width
+            block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
+            num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None),
+            stem_size=32,
+            fix_stem=fix_stem_head,
+            channel_multiplier=channel_multiplier,
+            norm_kwargs=resolve_bn_args(kwargs),
+            act_layer=nn.ReLU6,  # NOTE(review): hard-coded; an 'act_layer' left in kwargs would be a duplicate keyword
+            **kwargs
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """ FBNet-C
+
+    Paper: https://arxiv.org/abs/1812.03443
+    Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
+
+    NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper,
+    it was used to confirm some building block details
+    """
+    arch_def = [
+        ['ir_r1_k3_s1_e1_c16'],
+        ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
+        ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
+        ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
+        ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
+        ['ir_r4_k5_s2_e6_c184'],
+        ['ir_r1_k3_s1_e6_c352'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def),  # default depth_multiplier of 1.0, i.e. no depth scaling
+            stem_size=16,
+            num_features=1984,  # paper suggests this, but is not 100% clear
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'relu'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates the Single-Path NAS model from search targeted for Pixel1 phone.
+
+    Paper: https://arxiv.org/abs/1904.02877
+    Args:
+      variant: key into `model_urls`, looked up by `_create_model` when `pretrained` is set.
+      channel_multiplier: multiplier to number of channels per layer.
+    """
+    arch_def = [
+        # stage 0, 112x112 in
+        ['ds_r1_k3_s1_c16_noskip'],
+        # stage 1, 112x112 in
+        ['ir_r3_k3_s2_e3_c24'],
+        # stage 2, 56x56 in
+        ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
+        # stage 3, 28x28 in
+        ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
+        # stage 4, 14x14in
+        ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
+        # stage 5, 14x14in
+        ['ir_r4_k5_s2_e6_c192'],
+        # stage 6, 7x7 in
+        ['ir_r1_k3_s1_e6_c320_noskip']
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def),
+            stem_size=32,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'relu'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates an EfficientNet model.
+
+    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    Paper: https://arxiv.org/abs/1905.11946
+
+    EfficientNet params
+    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+    'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer
+      depth_multiplier: multiplier to number of repeats per stage
+      variant: key into `model_urls`, looked up by `_create_model` when `pretrained` is set
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_se0.25'],
+        ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'],
+        ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25'],
+        ['ir_r4_k5_s2_e6_c192_se0.25'],
+        ['ir_r1_k3_s1_e6_c320_se0.25'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def, depth_multiplier),
+            num_features=round_channels(1280, channel_multiplier, 8, None),  # head width scales with the channels
+            stem_size=32,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'swish'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs,
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates an EfficientNet-Edge model (stages 0-2 use the `er` block type, the remaining stages `ir`)."""
+    arch_def = [
+        # NOTE `fc` is present to override a stem channels vs. in chs mismatch not present in other models
+        ['er_r1_k3_s1_e4_c24_fc24_noskip'],
+        ['er_r2_k3_s2_e8_c32'],
+        ['er_r4_k3_s2_e8_c48'],
+        ['ir_r5_k5_s2_e8_c96'],
+        ['ir_r4_k5_s1_e8_c144'],
+        ['ir_r2_k5_s2_e8_c192'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def, depth_multiplier),
+            num_features=round_channels(1280, channel_multiplier, 8, None),
+            stem_size=32,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'relu'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs,
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_efficientnet_condconv(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
+    """Creates an efficientnet-condconv model; `experts_multiplier` scales the `cc` expert counts of arch_def."""
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_se0.25'],
+        ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'],
+        ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
+        ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
+        ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
+            num_features=round_channels(1280, channel_multiplier, 8, None),
+            stem_size=32,
+            channel_multiplier=channel_multiplier,
+            act_layer=resolve_act_layer(kwargs, 'swish'),
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs,
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates an EfficientNet-Lite model.
+
+    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
+    Paper: https://arxiv.org/abs/1905.11946
+
+    EfficientNet params
+    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
+      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
+      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
+      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
+      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer
+      depth_multiplier: multiplier to number of repeats per stage
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16'],
+        ['ir_r2_k3_s2_e6_c24'],
+        ['ir_r2_k5_s2_e6_c40'],
+        ['ir_r3_k3_s2_e6_c80'],
+        ['ir_r3_k5_s1_e6_c112'],
+        ['ir_r4_k5_s2_e6_c192'],
+        ['ir_r1_k3_s1_e6_c320'],
+    ]
+    with layer_config_kwargs(kwargs):
+        model_kwargs = dict(
+            block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),  # first/last stage depth fixed
+            num_features=1280,  # head width is not scaled by channel_multiplier here
+            stem_size=32,
+            fix_stem=True,  # stem width is not scaled by channel_multiplier either
+            channel_multiplier=channel_multiplier,
+            act_layer=nn.ReLU6,  # NOTE(review): hard-coded; an 'act_layer' left in kwargs would be a duplicate keyword
+            norm_kwargs=resolve_bn_args(kwargs),
+            **kwargs,
+        )
+        model = _create_model(model_kwargs, variant, pretrained)
+    return model
+
+
+def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MixNet Small model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
+ Paper: https://arxiv.org/abs/1907.09595
+
+ Args:
+ variant: model name handed to `_create_model`
+ channel_multiplier: stored in the model kwargs
+ pretrained: handed to `_create_model`
+ **kwargs: remaining arguments; read by `layer_config_kwargs`, `resolve_act_layer` and
+ `resolve_bn_args`, then expanded into the model kwargs
+
+ Returns:
+ The object returned by `_create_model`.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
+ # stage 2, 56x56 in
+ ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
+ # stage 3, 28x28 in
+ ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
+ # stage 4, 14x14in
+ ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
+ # stage 5, 14x14in
+ ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
+ # 7x7
+ ]
+ with layer_config_kwargs(kwargs):
+ # decode_arch_def is called without a depth multiplier here; 'relu' is the name handed to
+ # resolve_act_layer alongside kwargs.
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ num_features=1536,
+ stem_size=16,
+ channel_multiplier=channel_multiplier,
+ act_layer=resolve_act_layer(kwargs, 'relu'),
+ norm_kwargs=resolve_bn_args(kwargs),
+ **kwargs
+ )
+ model = _create_model(model_kwargs, variant, pretrained)
+ return model
+
+
+def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MixNet Medium-Large model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
+ Paper: https://arxiv.org/abs/1907.09595
+
+ Args:
+ variant: model name handed to `_create_model`
+ channel_multiplier: stored in the model kwargs
+ depth_multiplier: forwarded to `decode_arch_def` with depth_trunc='round'
+ pretrained: handed to `_create_model`
+ **kwargs: remaining arguments; read by `layer_config_kwargs`, `resolve_act_layer` and
+ `resolve_bn_args`, then expanded into the model kwargs
+
+ Returns:
+ The object returned by `_create_model`.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c24'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
+ # stage 2, 56x56 in
+ ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
+ # stage 3, 28x28 in
+ ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
+ # stage 4, 14x14in
+ ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
+ # stage 5, 14x14in
+ ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
+ # 7x7
+ ]
+ with layer_config_kwargs(kwargs):
+ # Keyword arguments are evaluated in order, so resolve_act_layer / resolve_bn_args
+ # receive kwargs before the trailing **kwargs expansion.
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
+ num_features=1536,
+ stem_size=24,
+ channel_multiplier=channel_multiplier,
+ act_layer=resolve_act_layer(kwargs, 'relu'),
+ norm_kwargs=resolve_bn_args(kwargs),
+ **kwargs
+ )
+ model = _create_model(model_kwargs, variant, pretrained)
+ return model
+
+
+def mnasnet_050(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 0.5. """
+ model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mnasnet_075(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 0.75. """
+ model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mnasnet_100(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 1.0. """
+ model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mnasnet_b1(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 1.0. """
+ return mnasnet_100(pretrained, **kwargs)
+
+
+def mnasnet_140(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 1.4 """
+ model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def semnasnet_050(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """
+ model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
+ return model
+
+
+def semnasnet_075(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """
+ model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+def semnasnet_100(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
+ model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mnasnet_a1(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
+ return semnasnet_100(pretrained, **kwargs)
+
+
+def semnasnet_140(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """
+ model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mnasnet_small(pretrained=False, **kwargs):
+ """ MNASNet Small, depth multiplier of 1.0. """
+ model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv2_100(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.0 channel multiplier """
+ model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv2_140(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.4 channel multiplier """
+ model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv2_110d(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers"""
+ model = _gen_mobilenet_v2(
+ 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv2_120d(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """
+ model = _gen_mobilenet_v2(
+ 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs)
+ return model
+
+
+def fbnetc_100(pretrained=False, **kwargs):
+ """ FBNet-C """
+ if pretrained:
+ # pretrained model trained with non-default BN epsilon
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def spnasnet_100(pretrained=False, **kwargs):
+ """ Single-Path NAS Pixel1"""
+ model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b0(pretrained=False, **kwargs):
+ """ EfficientNet-B0 """
+ # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b1(pretrained=False, **kwargs):
+ """ EfficientNet-B1 """
+ # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b2(pretrained=False, **kwargs):
+ """ EfficientNet-B2 """
+ # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b3(pretrained=False, **kwargs):
+ """ EfficientNet-B3 """
+ # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b4(pretrained=False, **kwargs):
+ """ EfficientNet-B4 """
+ # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b5(pretrained=False, **kwargs):
+ """ EfficientNet-B5 """
+ # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b6(pretrained=False, **kwargs):
+ """ EfficientNet-B6 """
+ # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b7(pretrained=False, **kwargs):
+ """ EfficientNet-B7 """
+ # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_b8(pretrained=False, **kwargs):
+ """ EfficientNet-B8 """
+ # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_l2(pretrained=False, **kwargs):
+ """ EfficientNet-L2. """
+ # NOTE for train, drop_rate should be 0.5
+ model = _gen_efficientnet(
+ 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_es(pretrained=False, **kwargs):
+ """ EfficientNet-Edge Small. """
+ model = _gen_efficientnet_edge(
+ 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_em(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Medium. """
+ model = _gen_efficientnet_edge(
+ 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_el(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Large. """
+ model = _gen_efficientnet_edge(
+ 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 8 Experts """
+ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
+ model = _gen_efficientnet_condconv(
+ 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 8 Experts """
+ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
+ model = _gen_efficientnet_condconv(
+ 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B1 w/ 8 Experts """
+ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
+ model = _gen_efficientnet_condconv(
+ 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_lite0(pretrained=False, **kwargs):
+ """ EfficientNet-Lite0 """
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_lite1(pretrained=False, **kwargs):
+ """ EfficientNet-Lite1 """
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_lite2(pretrained=False, **kwargs):
+ """ EfficientNet-Lite2 """
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_lite3(pretrained=False, **kwargs):
+ """ EfficientNet-Lite3 """
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def efficientnet_lite4(pretrained=False, **kwargs):
+ """ EfficientNet-Lite4 """
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b0(pretrained=False, **kwargs):
+ """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b1(pretrained=False, **kwargs):
+ """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b2(pretrained=False, **kwargs):
+ """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b3(pretrained=False, **kwargs):
+ """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b4(pretrained=False, **kwargs):
+ """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b5(pretrained=False, **kwargs):
+ """ EfficientNet-B5 RandAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b6(pretrained=False, **kwargs):
+ """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b7(pretrained=False, **kwargs):
+ """ EfficientNet-B7 RandAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b8(pretrained=False, **kwargs):
+ """ EfficientNet-B8 RandAug. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B0 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b1_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B1 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b2_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B2 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b3_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B3 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b4_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B4 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B5 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b6_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B6 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b7_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B7 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B8 AdvProp. Tensorflow compatible variant
+ Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
+ """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
+ """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant
+ Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252)
+ """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_es(pretrained=False, **kwargs):
+ """ EfficientNet-Edge Small. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_edge(
+ 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_em(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Medium. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_edge(
+ 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_el(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Large. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_edge(
+ 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 4 Experts """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_condconv(
+ 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 8 Experts """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_condconv(
+ 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B1 w/ 8 Experts """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_condconv(
+ 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_lite0(pretrained=False, **kwargs):
+ """ EfficientNet-Lite0. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_lite1(pretrained=False, **kwargs):
+ """ EfficientNet-Lite1. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_lite2(pretrained=False, **kwargs):
+ """ EfficientNet-Lite2. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_lite3(pretrained=False, **kwargs):
+ """ EfficientNet-Lite3. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_efficientnet_lite4(pretrained=False, **kwargs):
+ """ EfficientNet-Lite4. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mixnet_s(pretrained=False, **kwargs):
+ """Creates a MixNet Small model.
+ """
+ # NOTE for train set drop_rate=0.2
+ model = _gen_mixnet_s(
+ 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mixnet_m(pretrained=False, **kwargs):
+ """Creates a MixNet Medium model.
+ """
+ # NOTE for train set drop_rate=0.25
+ model = _gen_mixnet_m(
+ 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mixnet_l(pretrained=False, **kwargs):
+ """Creates a MixNet Large model.
+ """
+ # NOTE for train set drop_rate=0.25
+ model = _gen_mixnet_m(
+ 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mixnet_xl(pretrained=False, **kwargs):
+ """Creates a MixNet Extra-Large model.
+ Not a paper spec, experimental def by RW w/ depth scaling.
+ """
+ # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2
+ model = _gen_mixnet_m(
+ 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mixnet_xxl(pretrained=False, **kwargs):
+ """Creates a MixNet Double Extra Large model.
+ Not a paper spec, experimental def by RW w/ depth scaling.
+ """
+ # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2
+ model = _gen_mixnet_m(
+ 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mixnet_s(pretrained=False, **kwargs):
+ """Creates a MixNet Small model. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mixnet_s(
+ 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mixnet_m(pretrained=False, **kwargs):
+ """Creates a MixNet Medium model. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mixnet_m(
+ 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mixnet_l(pretrained=False, **kwargs):
+ """Creates a MixNet Large model. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mixnet_m(
+ 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
+ return model
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/helpers.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f83a07d690c7ad681c777c19b1e7a5bb95da007
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/helpers.py
@@ -0,0 +1,71 @@
+""" Checkpoint loading / state_dict helpers
+Copyright 2020 Ross Wightman
+"""
+import torch
+import os
+from collections import OrderedDict
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+
+def load_checkpoint(model, checkpoint_path):
+ if checkpoint_path and os.path.isfile(checkpoint_path):
+ print("=> Loading checkpoint '{}'".format(checkpoint_path))
+ checkpoint = torch.load(checkpoint_path)
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint['state_dict'].items():
+ if k.startswith('module'):
+ name = k[7:] # remove `module.`
+ else:
+ name = k
+ new_state_dict[name] = v
+ model.load_state_dict(new_state_dict)
+ else:
+ model.load_state_dict(checkpoint)
+ print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+ else:
+ print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
+ raise FileNotFoundError()
+
+
+def load_pretrained(model, url, filter_fn=None, strict=True):
+ if not url:
+ print("=> Warning: Pretrained model URL is empty, using random initialization.")
+ return
+
+ state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')
+
+ input_conv = 'conv_stem'
+ classifier = 'classifier'
+ in_chans = getattr(model, input_conv).weight.shape[1]
+ num_classes = getattr(model, classifier).weight.shape[0]
+
+ input_conv_weight = input_conv + '.weight'
+ pretrained_in_chans = state_dict[input_conv_weight].shape[1]
+ if in_chans != pretrained_in_chans:
+ if in_chans == 1:
+ print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
+ input_conv_weight, pretrained_in_chans))
+ conv1_weight = state_dict[input_conv_weight]
+ state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
+ else:
+ print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
+ input_conv_weight, pretrained_in_chans))
+ del state_dict[input_conv_weight]
+ strict = False
+
+ classifier_weight = classifier + '.weight'
+ pretrained_num_classes = state_dict[classifier_weight].shape[0]
+ if num_classes != pretrained_num_classes:
+ print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
+ del state_dict[classifier_weight]
+ del state_dict[classifier + '.bias']
+ strict = False
+
+ if filter_fn is not None:
+ state_dict = filter_fn(state_dict)
+
+ model.load_state_dict(state_dict, strict=strict)
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/mobilenetv3.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/mobilenetv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5966c28f7207e98ee50745b1bc8f3663c650f9d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/mobilenetv3.py
@@ -0,0 +1,364 @@
+""" MobileNet-V3
+
+A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
+
+Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .activations import get_act_fn, get_act_layer, HardSwish
+from .config import layer_config_kwargs
+from .conv2d_layers import select_conv2d
+from .helpers import load_pretrained
+from .efficientnet_builder import *
+
+# Names of the model entrypoint functions exported by this module.
+__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
+ 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
+ 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
+ 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']
+
+# Pretrained weight URL for each variant name. A value of None means no URL is listed for
+# that variant; _create_model then skips weight loading even when pretrained=True.
+model_urls = {
+ 'mobilenetv3_rw':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
+ 'mobilenetv3_large_075': None,
+ 'mobilenetv3_large_100':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
+ 'mobilenetv3_large_minimal_100': None,
+ 'mobilenetv3_small_075': None,
+ 'mobilenetv3_small_100': None,
+ 'mobilenetv3_small_minimal_100': None,
+ 'tf_mobilenetv3_large_075':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
+ 'tf_mobilenetv3_large_100':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
+ 'tf_mobilenetv3_large_minimal_100':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
+ 'tf_mobilenetv3_small_075':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
+ 'tf_mobilenetv3_small_100':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
+ 'tf_mobilenetv3_small_minimal_100':
+ 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
+}
+
+
+class MobileNetV3(nn.Module):
+ """ MobileNet-V3
+
+ A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the
+ head convolution without a final batch-norm layer before the classifier.
+
+ Paper: https://arxiv.org/abs/1905.02244
+ """
+
+ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
+ channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0.,
+ se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
+ super(MobileNetV3, self).__init__()
+ self.drop_rate = drop_rate
+
+ stem_size = round_channels(stem_size, channel_multiplier)
+ self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+ self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs)
+ self.act1 = act_layer(inplace=True)
+ in_chs = stem_size
+
+ builder = EfficientNetBuilder(
+ channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs,
+ norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate)
+ self.blocks = nn.Sequential(*builder(in_chs, block_args))
+ in_chs = builder.in_chs
+
+ self.global_pool = nn.AdaptiveAvgPool2d(1)
+ self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias)
+ self.act2 = act_layer(inplace=True)
+ self.classifier = nn.Linear(num_features, num_classes)
+
+ for m in self.modules():
+ if weight_init == 'goog':
+ initialize_weight_goog(m)
+ else:
+ initialize_weight_default(m)
+
+ def as_sequential(self):
+ layers = [self.conv_stem, self.bn1, self.act1]
+ layers.extend(self.blocks)
+ layers.extend([
+ self.global_pool, self.conv_head, self.act2,
+ nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
+ return nn.Sequential(*layers)
+
+ def features(self, x):
+ x = self.conv_stem(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.blocks(x)
+ x = self.global_pool(x)
+ x = self.conv_head(x)
+ x = self.act2(x)
+ return x
+
+ def forward(self, x):
+ x = self.features(x)
+ x = x.flatten(1)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ return self.classifier(x)
+
+
+def _create_model(model_kwargs, variant, pretrained=False):
+ as_sequential = model_kwargs.pop('as_sequential', False)
+ model = MobileNetV3(**model_kwargs)
+ if pretrained and model_urls[variant]:
+ load_pretrained(model, model_urls[variant])
+ if as_sequential:
+ model = model.as_sequential()
+ return model
+
+
+def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MobileNet-V3 model (RW variant).
+
+ Paper: https://arxiv.org/abs/1905.02244
+
+ This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the
+ eventual Tensorflow reference impl but has a few differences:
+ 1. This model has no bias on the head convolution
+ 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet
+ 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer
+ from their parent block
+ 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count
+
+ Overall the changes are fairly minor and result in a very small parameter count difference and no
+ top-1/5
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
+ # stage 2, 56x56 in
+ ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
+ # stage 3, 28x28 in
+ ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
+ # stage 5, 14x14in
+ ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c960'], # hard-swish
+ ]
+ with layer_config_kwargs(kwargs):
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ head_bias=False, # one of my mistakes
+ channel_multiplier=channel_multiplier,
+ act_layer=resolve_act_layer(kwargs, 'hard_swish'),
+ se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True),
+ norm_kwargs=resolve_bn_args(kwargs),
+ **kwargs,
+ )
+ model = _create_model(model_kwargs, variant, pretrained)
+ return model
+
+
+def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MobileNet-V3 large/small/minimal models.
+
+ Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py
+ Paper: https://arxiv.org/abs/1905.02244
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ if 'small' in variant:
+ num_features = 1024
+ if 'minimal' in variant:
+ act_layer = 'relu'
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s2_e1_c16'],
+ # stage 1, 56x56 in
+ ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
+ # stage 2, 28x28 in
+ ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
+ # stage 3, 14x14 in
+ ['ir_r2_k3_s1_e3_c48'],
+ # stage 4, 14x14in
+ ['ir_r3_k3_s2_e6_c96'],
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c576'],
+ ]
+ else:
+ act_layer = 'hard_swish'
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
+ # stage 1, 56x56 in
+ ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
+ # stage 2, 28x28 in
+ ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
+ # stage 3, 14x14 in
+ ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
+ # stage 4, 14x14in
+ ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c576'], # hard-swish
+ ]
+ else:
+ num_features = 1280
+ if 'minimal' in variant:
+ act_layer = 'relu'
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16'],
+ # stage 1, 112x112 in
+ ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
+ # stage 2, 56x56 in
+ ['ir_r3_k3_s2_e3_c40'],
+ # stage 3, 28x28 in
+ ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112'],
+ # stage 5, 14x14in
+ ['ir_r3_k3_s2_e6_c160'],
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c960'],
+ ]
+ else:
+ act_layer = 'hard_swish'
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16_nre'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
+ # stage 2, 56x56 in
+ ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
+ # stage 3, 28x28 in
+ ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
+ # stage 5, 14x14in
+ ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c960'], # hard-swish
+ ]
+ with layer_config_kwargs(kwargs):
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ num_features=num_features,
+ stem_size=16,
+ channel_multiplier=channel_multiplier,
+ act_layer=resolve_act_layer(kwargs, act_layer),
+ se_kwargs=dict(
+ act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8),
+ norm_kwargs=resolve_bn_args(kwargs),
+ **kwargs,
+ )
+ model = _create_model(model_kwargs, variant, pretrained)
+ return model
+
+
+def mobilenetv3_rw(pretrained=False, **kwargs):
+ """ MobileNet-V3 RW
+ Attn: See note in gen function for this variant.
+ """
+ # NOTE for train set drop_rate=0.2
+ if pretrained:
+ # pretrained model trained with non-default BN epsilon
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv3_large_075(pretrained=False, **kwargs):
+ """ MobileNet V3 Large 0.75"""
+ # NOTE for train set drop_rate=0.2
+ model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv3_large_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Large 1.0 """
+ # NOTE for train set drop_rate=0.2
+ model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Large (Minimalistic) 1.0 """
+ # NOTE for train set drop_rate=0.2
+ model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv3_small_075(pretrained=False, **kwargs):
+ """ MobileNet V3 Small 0.75 """
+ model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv3_small_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Small 1.0 """
+ model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Small (Minimalistic) 1.0 """
+ model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
+ """ MobileNet V3 Large 0.75. Tensorflow compat variant. """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Large 1.0. Tensorflow compat variant. """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
+ """ MobileNet V3 Small 0.75. Tensorflow compat variant. """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Small 1.0. Tensorflow compat variant."""
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
+ """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/model_factory.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/model_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d46ea8baedaf3d787826eb3bb314b4230514647
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/model_factory.py
@@ -0,0 +1,27 @@
+from .config import set_layer_config
+from .helpers import load_checkpoint
+
+from .gen_efficientnet import *
+from .mobilenetv3 import *
+
+
+def create_model(
+ model_name='mnasnet_100',
+ pretrained=None,
+ num_classes=1000,
+ in_chans=3,
+ checkpoint_path='',
+ **kwargs):
+
+ model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs)
+
+ if model_name in globals():
+ create_fn = globals()[model_name]
+ model = create_fn(**model_kwargs)
+ else:
+ raise RuntimeError('Unknown model (%s)' % model_name)
+
+ if checkpoint_path and not pretrained:
+ load_checkpoint(model, checkpoint_path)
+
+ return model
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/version.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6221b3de7b1490c5e712e8b5fcc94c3d9d04295
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/geffnet/version.py
@@ -0,0 +1 @@
+__version__ = '1.0.2'
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/hubconf.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/hubconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b17b99bbeba34596569e6e50f6e8a2ebc45c54
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/hubconf.py
@@ -0,0 +1,84 @@
+dependencies = ['torch', 'math']
+
+from geffnet import efficientnet_b0
+from geffnet import efficientnet_b1
+from geffnet import efficientnet_b2
+from geffnet import efficientnet_b3
+
+from geffnet import efficientnet_es
+
+from geffnet import efficientnet_lite0
+
+from geffnet import mixnet_s
+from geffnet import mixnet_m
+from geffnet import mixnet_l
+from geffnet import mixnet_xl
+
+from geffnet import mobilenetv2_100
+from geffnet import mobilenetv2_110d
+from geffnet import mobilenetv2_120d
+from geffnet import mobilenetv2_140
+
+from geffnet import mobilenetv3_large_100
+from geffnet import mobilenetv3_rw
+from geffnet import mnasnet_a1
+from geffnet import mnasnet_b1
+from geffnet import fbnetc_100
+from geffnet import spnasnet_100
+
+from geffnet import tf_efficientnet_b0
+from geffnet import tf_efficientnet_b1
+from geffnet import tf_efficientnet_b2
+from geffnet import tf_efficientnet_b3
+from geffnet import tf_efficientnet_b4
+from geffnet import tf_efficientnet_b5
+from geffnet import tf_efficientnet_b6
+from geffnet import tf_efficientnet_b7
+from geffnet import tf_efficientnet_b8
+
+from geffnet import tf_efficientnet_b0_ap
+from geffnet import tf_efficientnet_b1_ap
+from geffnet import tf_efficientnet_b2_ap
+from geffnet import tf_efficientnet_b3_ap
+from geffnet import tf_efficientnet_b4_ap
+from geffnet import tf_efficientnet_b5_ap
+from geffnet import tf_efficientnet_b6_ap
+from geffnet import tf_efficientnet_b7_ap
+from geffnet import tf_efficientnet_b8_ap
+
+from geffnet import tf_efficientnet_b0_ns
+from geffnet import tf_efficientnet_b1_ns
+from geffnet import tf_efficientnet_b2_ns
+from geffnet import tf_efficientnet_b3_ns
+from geffnet import tf_efficientnet_b4_ns
+from geffnet import tf_efficientnet_b5_ns
+from geffnet import tf_efficientnet_b6_ns
+from geffnet import tf_efficientnet_b7_ns
+from geffnet import tf_efficientnet_l2_ns_475
+from geffnet import tf_efficientnet_l2_ns
+
+from geffnet import tf_efficientnet_es
+from geffnet import tf_efficientnet_em
+from geffnet import tf_efficientnet_el
+
+from geffnet import tf_efficientnet_cc_b0_4e
+from geffnet import tf_efficientnet_cc_b0_8e
+from geffnet import tf_efficientnet_cc_b1_8e
+
+from geffnet import tf_efficientnet_lite0
+from geffnet import tf_efficientnet_lite1
+from geffnet import tf_efficientnet_lite2
+from geffnet import tf_efficientnet_lite3
+from geffnet import tf_efficientnet_lite4
+
+from geffnet import tf_mixnet_s
+from geffnet import tf_mixnet_m
+from geffnet import tf_mixnet_l
+
+from geffnet import tf_mobilenetv3_large_075
+from geffnet import tf_mobilenetv3_large_100
+from geffnet import tf_mobilenetv3_large_minimal_100
+from geffnet import tf_mobilenetv3_small_075
+from geffnet import tf_mobilenetv3_small_100
+from geffnet import tf_mobilenetv3_small_minimal_100
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_export.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a5162ce214830df501bdb81edb66c095122f69d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_export.py
@@ -0,0 +1,120 @@
+""" ONNX export script
+
+Export PyTorch models as ONNX graphs.
+
+This export script originally started as an adaptation of code snippets found at
+https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
+
+The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph
+for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible
+with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback
+flags are currently required.
+
+Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for
+caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime.
+
+Most new releases of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models.
+Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks.
+
+Copyright 2020 Ross Wightman
+"""
+import argparse
+import torch
+import numpy as np
+
+import onnx
+import geffnet
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
+parser.add_argument('output', metavar='ONNX_FILE',
+ help='output model filename')
+parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100',
+ help='model architecture (default: mobilenetv3_large_100)')
+parser.add_argument('--opset', type=int, default=10,
+ help='ONNX opset to use (default: 10)')
+parser.add_argument('--keep-init', action='store_true', default=False,
+ help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.')
+parser.add_argument('--aten-fallback', action='store_true', default=False,
+ help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.')
+parser.add_argument('--dynamic-size', action='store_true', default=False,
+ help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.')
+parser.add_argument('-b', '--batch-size', default=1, type=int,
+ metavar='N', help='mini-batch size (default: 1)')
+parser.add_argument('--img-size', default=None, type=int,
+ metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+ help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+ help='Override std deviation of of dataset')
+parser.add_argument('--num-classes', type=int, default=1000,
+ help='Number classes in dataset')
+parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
+ help='path to checkpoint (default: none)')
+
+
+def main():
+ args = parser.parse_args()
+
+ args.pretrained = True
+ if args.checkpoint:
+ args.pretrained = False
+
+ print("==> Creating PyTorch {} model".format(args.model))
+ # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers
+ # for models using SAME padding
+ model = geffnet.create_model(
+ args.model,
+ num_classes=args.num_classes,
+ in_chans=3,
+ pretrained=args.pretrained,
+ checkpoint_path=args.checkpoint,
+ exportable=True)
+
+ model.eval()
+
+ example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True)
+
+ # Run model once before export trace, sets padding for models with Conv2dSameExport. This means
+ # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for
+ # the input img_size specified in this script.
+ # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to
+ # issues in the tracing of the dynamic padding or errors attempting to export the model after jit
+ # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions...
+ model(example_input)
+
+ print("==> Exporting model to ONNX format at '{}'".format(args.output))
+ input_names = ["input0"]
+ output_names = ["output0"]
+ dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}}
+ if args.dynamic_size:
+ dynamic_axes['input0'][2] = 'height'
+ dynamic_axes['input0'][3] = 'width'
+ if args.aten_fallback:
+ export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
+ else:
+ export_type = torch.onnx.OperatorExportTypes.ONNX
+
+ torch_out = torch.onnx._export(
+ model, example_input, args.output, export_params=True, verbose=True, input_names=input_names,
+ output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes,
+ opset_version=args.opset, operator_export_type=export_type)
+
+ print("==> Loading and checking exported model from '{}'".format(args.output))
+ onnx_model = onnx.load(args.output)
+ onnx.checker.check_model(onnx_model) # assuming throw on error
+ print("==> Passed")
+
+ if args.keep_init and args.aten_fallback:
+ import caffe2.python.onnx.backend as onnx_caffe2
+ # Caffe2 loading only works properly in newer PyTorch/ONNX combos when
+ # keep_initializers_as_inputs and aten_fallback are set to True.
+ print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output))
+ caffe2_backend = onnx_caffe2.prepare(onnx_model)
+ B = {onnx_model.graph.input[0].name: x.data.numpy()}
+ c2_out = caffe2_backend.run(B)[0]
+ np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5)
+ print("==> Passed")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_optimize.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_optimize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee20bbf9f0f9473370489512eb96ca0b570b5388
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_optimize.py
@@ -0,0 +1,84 @@
+""" ONNX optimization script
+
+Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.
+
+NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7),
+it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline).
+
+Copyright 2020 Ross Wightman
+"""
+import argparse
+import warnings
+
+import onnx
+from onnx import optimizer
+
+
+parser = argparse.ArgumentParser(description="Optimize ONNX model")
+
+parser.add_argument("model", help="The ONNX model")
+parser.add_argument("--output", required=True, help="The optimized model output filename")
+
+
+def traverse_graph(graph, prefix=''):
+ content = []
+ indent = prefix + ' '
+ graphs = []
+ num_nodes = 0
+ for node in graph.node:
+ pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
+ assert isinstance(gs, list)
+ content.append(pn)
+ graphs.extend(gs)
+ num_nodes += 1
+ for g in graphs:
+ g_count, g_str = traverse_graph(g)
+ content.append('\n' + g_str)
+ num_nodes += g_count
+ return num_nodes, '\n'.join(content)
+
+
+def main():
+ args = parser.parse_args()
+ onnx_model = onnx.load(args.model)
+ num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)
+
+ # Optimizer passes to perform
+ passes = [
+ #'eliminate_deadend',
+ 'eliminate_identity',
+ 'eliminate_nop_dropout',
+ 'eliminate_nop_pad',
+ 'eliminate_nop_transpose',
+ 'eliminate_unused_initializer',
+ 'extract_constant_to_initializer',
+ 'fuse_add_bias_into_conv',
+ 'fuse_bn_into_conv',
+ 'fuse_consecutive_concats',
+ 'fuse_consecutive_reduce_unsqueeze',
+ 'fuse_consecutive_squeezes',
+ 'fuse_consecutive_transposes',
+ #'fuse_matmul_add_bias_into_gemm',
+ 'fuse_pad_into_conv',
+ #'fuse_transpose_into_gemm',
+ #'lift_lexical_references',
+ ]
+
+ # Apply the optimization on the original serialized model
+ # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing
+ # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401
+ # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script.
+ warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX."
+ "Try onnxruntime optimization if this doesn't work.")
+ optimized_model = optimizer.optimize(onnx_model, passes)
+
+ num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph)
+ print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str))
+ print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes))
+
+ # Save the ONNX model
+ onnx.save(optimized_model, args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_to_caffe.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_to_caffe.py
new file mode 100644
index 0000000000000000000000000000000000000000..44399aafababcdf6b84147a0613eb0909730db4b
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_to_caffe.py
@@ -0,0 +1,27 @@
+import argparse
+
+import onnx
+from caffe2.python.onnx.backend import Caffe2Backend
+
+
+parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2")
+
+parser.add_argument("model", help="The ONNX model")
+parser.add_argument("--c2-prefix", required=True,
+ help="The output file prefix for the caffe2 model init and predict file. ")
+
+
+def main():
+ args = parser.parse_args()
+ onnx_model = onnx.load(args.model)
+ caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
+ caffe2_init_str = caffe2_init.SerializeToString()
+ with open(args.c2_prefix + '.init.pb', "wb") as f:
+ f.write(caffe2_init_str)
+ caffe2_predict_str = caffe2_predict.SerializeToString()
+ with open(args.c2_prefix + '.predict.pb', "wb") as f:
+ f.write(caffe2_predict_str)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_validate.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_validate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab3e4fb141b6ef660dcc5b447fd9f368a2ea19a0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/onnx_validate.py
@@ -0,0 +1,112 @@
+""" ONNX-runtime validation script
+
+This script was created to verify accuracy and performance of exported ONNX
+models running with the onnxruntime. It utilizes the PyTorch dataloader/processing
+pipeline for a fair comparison against the originals.
+
+Copyright 2020 Ross Wightman
+"""
+import argparse
+import numpy as np
+import onnxruntime
+from data import create_loader, resolve_data_config, Dataset
+from utils import AverageMeter
+import time
+
+parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
+parser.add_argument('data', metavar='DIR',
+ help='path to dataset')
+parser.add_argument('--onnx-input', default='', type=str, metavar='PATH',
+ help='path to onnx model/weights file')
+parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH',
+ help='path to output optimized onnx graph')
+parser.add_argument('--profile', action='store_true', default=False,
+ help='Enable profiler output.')
+parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
+ help='number of data loading workers (default: 2)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+ metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--img-size', default=None, type=int,
+ metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+ help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+ help='Override std deviation of of dataset')
+parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
+ help='Override default crop pct of 0.875')
+parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
+ help='Image resize interpolation type (overrides model)')
+parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
+ help='use tensorflow mnasnet preporcessing')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+ metavar='N', help='print frequency (default: 10)')
+
+
+def main():
+ args = parser.parse_args()
+ args.gpu_id = 0
+
+ # Set graph optimization level
+ sess_options = onnxruntime.SessionOptions()
+ sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+ if args.profile:
+ sess_options.enable_profiling = True
+ if args.onnx_output_opt:
+ sess_options.optimized_model_filepath = args.onnx_output_opt
+
+ session = onnxruntime.InferenceSession(args.onnx_input, sess_options)
+
+ data_config = resolve_data_config(None, args)
+ loader = create_loader(
+ Dataset(args.data, load_bytes=args.tf_preprocessing),
+ input_size=data_config['input_size'],
+ batch_size=args.batch_size,
+ use_prefetcher=False,
+ interpolation=data_config['interpolation'],
+ mean=data_config['mean'],
+ std=data_config['std'],
+ num_workers=args.workers,
+ crop_pct=data_config['crop_pct'],
+ tensorflow_preprocessing=args.tf_preprocessing)
+
+ input_name = session.get_inputs()[0].name
+
+ batch_time = AverageMeter()
+ top1 = AverageMeter()
+ top5 = AverageMeter()
+ end = time.time()
+ for i, (input, target) in enumerate(loader):
+ # run the net and return prediction
+ output = session.run([], {input_name: input.data.numpy()})
+ output = output[0]
+
+ # measure accuracy and record loss
+ prec1, prec5 = accuracy_np(output, target.numpy())
+ top1.update(prec1.item(), input.size(0))
+ top5.update(prec5.item(), input.size(0))
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if i % args.print_freq == 0:
+ print('Test: [{0}/{1}]\t'
+ 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
+ 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+ 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+ i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
+ ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
+
+ print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
+ top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
+
+
+def accuracy_np(output, target):
+ max_indices = np.argsort(output, axis=1)[:, ::-1]
+ top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
+ top1 = 100 * np.equal(max_indices[:, 0], target).mean()
+ return top1, top5
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/requirements.txt b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ac3ffc13bae15f9b11f7cbe3705760056ecd7f13
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/requirements.txt
@@ -0,0 +1,2 @@
+torch>=1.2.0
+torchvision>=0.4.0
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/setup.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..023e4c30f98164595964423e3a83eefaf7ffdad6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/setup.py
@@ -0,0 +1,47 @@
+""" Setup
+"""
+from setuptools import setup, find_packages
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+# Get the long description from the README file
+with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+    long_description = f.read()
+# Define __version__ without importing the package; path is anchored at this file, not the CWD.
+with open(path.join(here, 'geffnet', 'version.py'), encoding='utf-8') as f:
+    exec(f.read())
+setup(
+    name='geffnet',
+    version=__version__,
+    description='(Generic) EfficientNets for PyTorch',
+    long_description=long_description,
+    long_description_content_type='text/markdown',
+    url='https://github.com/rwightman/gen-efficientnet-pytorch',
+    author='Ross Wightman',
+    author_email='hello@rwightman.com',
+    classifiers=[
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        'Development Status :: 3 - Alpha',
+        'Intended Audience :: Education',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: Apache Software License',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Topic :: Scientific/Engineering',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Topic :: Software Development',
+        'Topic :: Software Development :: Libraries',
+        'Topic :: Software Development :: Libraries :: Python Modules',
+    ],
+
+    # Note that this is a string of words separated by whitespace, not a list.
+    keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet',
+    packages=find_packages(exclude=['data']),
+    install_requires=['torch >= 1.4', 'torchvision'],
+    python_requires='>=3.6',
+)
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/utils.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d327e8bd8120c5cd09ae6c15c3991ccbe27f6c1f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/utils.py
@@ -0,0 +1,52 @@
+import os
+
+
+class AverageMeter:
+    """Tracks the most recent value together with a running weighted mean."""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        # Zero every statistic so the meter can be reused for a new pass.
+        self.val = self.avg = self.sum = self.count = 0
+
+    def update(self, val, n=1):
+        # `val` is weighted by `n` (the callers in validate.py pass the batch size).
+        total = self.sum + val * n
+        seen = self.count + n
+        self.val, self.sum, self.count = val, total, seen
+        # Raises ZeroDivisionError while the accumulated count is still zero.
+        self.avg = total / seen
+
+
+def accuracy(output, target, topk=(1,)):
+    """Return one precision@k percentage (0-dim tensor) per entry of `topk`."""
+    n_samples = target.size(0)
+
+    # Top-max(k) class indices per sample, transposed so that row r holds every sample's rank-r guess.
+    _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
+    ranked = ranked.t()
+    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
+
+    def _pct(k):
+        # A sample counts as correct when its target appears among its first k guesses.
+        return hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
+
+    return [_pct(k) for k in topk]
+
+
+def get_outdir(path, *paths, inc=False):
+    # Join the parts and create the directory; with inc=True an existing one gets a fresh '-N' sibling.
+    base = os.path.join(path, *paths)
+    if not os.path.exists(base):
+        os.makedirs(base)
+        return base
+    if not inc:
+        return base
+    suffix = 1
+    while os.path.exists(base + '-' + str(suffix)):
+        suffix += 1
+        assert suffix < 100
+    os.makedirs(base + '-' + str(suffix))
+    return base + '-' + str(suffix)
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/validate.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/validate.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd44fbb3165ef81ef81251b6299f6aaa80bf2c2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/efficientnet_repo/validate.py
@@ -0,0 +1,166 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import time
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+from contextlib import suppress
+
+import geffnet
+from data import Dataset, create_loader, resolve_data_config
+from utils import accuracy, AverageMeter
+
+# Native AMP (torch.cuda.amp.autocast) is probed once at import; --amp falls back to FP32 without it.
+try:
+    has_native_amp = getattr(torch.cuda.amp, 'autocast') is not None
+except AttributeError:
+    # torch.cuda.amp, or its autocast attribute, does not exist in this torch build.
+    has_native_amp = False
+
+torch.backends.cudnn.benchmark = True  # set globally at import time
+
+parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')  # CLI consumed by main()
+parser.add_argument('data', metavar='DIR',
+                    help='path to dataset')
+parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00',
+                    help='model architecture (default: spnasnet1_00)')
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N', help='mini-batch size (default: 256)')
+parser.add_argument('--img-size', default=None, type=int,
+                    metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+                    help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+                    help='Override std deviation of dataset')
+parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
+                    help='Override default crop pct of 0.875')
+parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
+                    help='Image resize interpolation type (overrides model)')
+parser.add_argument('--num-classes', type=int, default=1000,
+                    help='Number classes in dataset')
+parser.add_argument('--print-freq', '-p', default=10, type=int,
+                    metavar='N', help='print frequency (default: 10)')
+parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                    help='use pre-trained model')
+parser.add_argument('--torchscript', dest='torchscript', action='store_true',
+                    help='convert model to torchscript for inference')
+parser.add_argument('--num-gpu', type=int, default=1,
+                    help='Number of GPUS to use')
+parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
+                    help='use tensorflow mnasnet preprocessing')
+parser.add_argument('--no-cuda', dest='no_cuda', action='store_true',
+                    help='disable CUDA and run on the CPU')
+parser.add_argument('--channels-last', action='store_true', default=False,
+                    help='Use channels_last memory layout')
+parser.add_argument('--amp', action='store_true', default=False,
+                    help='Use native Torch AMP mixed precision.')
+
+
+def main():
+    """Evaluate a geffnet model on the dataset at `args.data`, printing loss, Prec@1 and Prec@5."""
+    args = parser.parse_args()
+    if not args.checkpoint and not args.pretrained:
+        args.pretrained = True  # no checkpoint given: fall back to pretrained weights
+
+    amp_autocast = suppress  # no-op context manager used when AMP is off or unavailable
+    if args.amp:
+        if not has_native_amp:
+            print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.")
+        else:
+            amp_autocast = torch.cuda.amp.autocast
+
+    # create model
+    model = geffnet.create_model(
+        args.model,
+        num_classes=args.num_classes,
+        in_chans=3,
+        pretrained=args.pretrained,
+        checkpoint_path=args.checkpoint,
+        scriptable=args.torchscript)
+
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
+
+    if args.torchscript:
+        torch.jit.optimized_execution(True)  # NOTE(review): result is discarded; if this is a context manager it has no effect - confirm
+        model = torch.jit.script(model)
+
+    print('Model %s created, param count: %d' %
+          (args.model, sum([m.numel() for m in model.parameters()])))
+
+    data_config = resolve_data_config(model, args)  # supplies input_size / interpolation / mean / std / crop_pct used below
+
+    criterion = nn.CrossEntropyLoss()
+
+    if not args.no_cuda:
+        if args.num_gpu > 1:
+            model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
+        else:
+            model = model.cuda()
+        criterion = criterion.cuda()
+
+    loader = create_loader(
+        Dataset(args.data, load_bytes=args.tf_preprocessing),
+        input_size=data_config['input_size'],
+        batch_size=args.batch_size,
+        use_prefetcher=not args.no_cuda,
+        interpolation=data_config['interpolation'],
+        mean=data_config['mean'],
+        std=data_config['std'],
+        num_workers=args.workers,
+        crop_pct=data_config['crop_pct'],
+        tensorflow_preprocessing=args.tf_preprocessing)
+
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    top1 = AverageMeter()
+    top5 = AverageMeter()
+
+    model.eval()
+    end = time.time()
+    with torch.no_grad():
+        for i, (input, target) in enumerate(loader):
+            if not args.no_cuda:
+                target = target.cuda()
+                input = input.cuda()
+            if args.channels_last:
+                input = input.contiguous(memory_format=torch.channels_last)
+
+            # compute output (under autocast only when --amp is active)
+            with amp_autocast():
+                output = model(input)
+                loss = criterion(output, target)
+
+            # measure accuracy and record loss, each weighted by the batch size
+            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+            losses.update(loss.item(), input.size(0))
+            top1.update(prec1.item(), input.size(0))
+            top5.update(prec5.item(), input.size(0))
+
+            # measure elapsed time for this batch (includes data loading since the previous `end`)
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            if i % args.print_freq == 0:
+                print('Test: [{0}/{1}]\t'
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
+                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
+                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
+                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
+                    i, len(loader), batch_time=batch_time,
+                    rate_avg=input.size(0) / batch_time.avg,  # current batch size over the mean batch time
+                    loss=losses, top1=top1, top5=top5))
+
+    print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
+        top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))  # bracketed values are 100 - accuracy
+
+
+if __name__ == '__main__':
+    main()  # run the validation only when this file is executed as a script
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/encoder.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f7149ca3c0cf2b6e019105af7e645cfbb3eda11
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/encoder.py
@@ -0,0 +1,34 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+    """Backbone wrapper around the bundled `tf_efficientnet_b5_ap` that returns every intermediate activation."""
+    def __init__(self):
+        super(Encoder, self).__init__()
+        basemodel_name = 'tf_efficientnet_b5_ap'
+        print('Loading base model ({})...'.format(basemodel_name), end='')
+        # Built from the local hub repo next to this file; pretrained=False is passed (weights presumably loaded by the caller - verify).
+        repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo')
+        basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local')
+        print('Done.')
+        # Replace the pooling and classifier heads with no-ops so forward() ends on feature maps
+        print('Removing last two layers (global_pool & classifier).')
+        basemodel.global_pool = nn.Identity()
+        basemodel.classifier = nn.Identity()
+        self.original_model = basemodel
+
+    def forward(self, x):
+        # features[0] is the input; each child module appends its output, with the 'blocks' container expanded per stage.
+        features = [x]
+        for k, v in self.original_model._modules.items():
+            if (k == 'blocks'):
+                for ki, vi in v._modules.items():
+                    features.append(vi(features[-1]))
+            else:
+                features.append(v(features[-1]))
+        return features
+
+
diff --git a/sd-webui-controlnet/annotator/normalbae/models/submodules/submodules.py b/sd-webui-controlnet/annotator/normalbae/models/submodules/submodules.py
new file mode 100644
index 0000000000000000000000000000000000000000..409733351bd6ab5d191c800aff1bc05bfa4cb6f8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/normalbae/models/submodules/submodules.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+########################################################################################################################
+
+
+# Upsample + BatchNorm
+class UpSampleBN(nn.Module):
+    """Decoder stage: resize `x` to the skip tensor's size, concatenate, then two 3x3 conv + BatchNorm + LeakyReLU."""
+    def __init__(self, skip_input, output_features):
+        super(UpSampleBN, self).__init__()
+        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
+        self._net = nn.Sequential(
+            nn.Conv2d(skip_input, output_features, **conv_kwargs), nn.BatchNorm2d(output_features), nn.LeakyReLU(),
+            nn.Conv2d(output_features, output_features, **conv_kwargs), nn.BatchNorm2d(output_features), nn.LeakyReLU(),
+        )
+
+    def forward(self, x, concat_with):
+        # Bilinear resize (align_corners=True) to the skip feature map's H x W before the channel-wise concat.
+        target_hw = [concat_with.size(2), concat_with.size(3)]
+        resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
+        return self._net(torch.cat([resized, concat_with], dim=1))
+
+
+# Upsample + GroupNorm + Weight Standardization
+class UpSampleGN(nn.Module):
+    """Decoder stage: resize, concatenate with the skip tensor, then two (weight-standardized Conv2d, GroupNorm(8), LeakyReLU)."""
+    def __init__(self, skip_input, output_features):
+        super(UpSampleGN, self).__init__()
+        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
+        self._net = nn.Sequential(
+            Conv2d(skip_input, output_features, **conv_kwargs), nn.GroupNorm(8, output_features), nn.LeakyReLU(),
+            Conv2d(output_features, output_features, **conv_kwargs), nn.GroupNorm(8, output_features), nn.LeakyReLU(),
+        )
+
+    def forward(self, x, concat_with):
+        # Bilinear resize (align_corners=True) to the skip feature map's H x W before the channel-wise concat.
+        target_hw = [concat_with.size(2), concat_with.size(3)]
+        resized = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
+        return self._net(torch.cat([resized, concat_with], dim=1))
+
+
+# Conv2d with weight standardization
+class Conv2d(nn.Conv2d):
+    """nn.Conv2d whose kernel is re-standardized (per output channel) on every forward pass."""
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
+                                     padding, dilation, groups, bias)
+
+    def forward(self, x):
+        # Per-filter mean, reduced one axis at a time over dims 1, 2 and 3 of the weight.
+        mu = self.weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
+        centered = self.weight - mu
+        # Per-filter std of the centered kernel; the added 1e-5 keeps the division finite.
+        sigma = centered.view(centered.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        standardized = centered / sigma.expand_as(centered)
+        return F.conv2d(x, standardized, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+# normalize
+def norm_normalize(norm_out):
+    """Scale channels 0-2 of `norm_out` to (near) unit length and map channel 3 (kappa) to elu(kappa) + 1 + 0.01."""
+    nx, ny, nz, raw_kappa = torch.split(norm_out, 1, dim=1)
+    length = torch.sqrt(nx ** 2.0 + ny ** 2.0 + nz ** 2.0) + 1e-10
+    min_kappa = 0.01  # offset that keeps the kappa output strictly above this value
+    kappa = F.elu(raw_kappa) + 1.0 + min_kappa
+    return torch.cat([nx / length, ny / length, nz / length, kappa], dim=1)
+
+
+# uncertainty-guided sampling (only used during training)
+@torch.no_grad()
+def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
+    """Choose N = int(sampling_ratio * H * W) pixels per image: the int(beta * N) with the lowest
+    channel-3 value of `init_normal` first, then random picks from the remaining pixels.
+    Returns (point_coords [B, 1, N, 2] holding x, y in [-1, 1], rows_int, cols_int)."""
+    device = init_normal.device
+    B, _, H, W = init_normal.shape
+    N = int(sampling_ratio * H * W)
+    num_importance = int(beta * N)
+
+    # uncertainty map: negated channel 3, so the descending sort below lists the smallest values first
+    uncertainty_map = -1 * init_normal[:, 3, :, :]  # B, H, W
+
+    # gt_invalid_mask (B, H, W): masked-out pixels get a very low score and sort to the end
+    if gt_norm_mask is not None:
+        gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
+        gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
+        uncertainty_map[gt_invalid_mask] = -1e4
+
+    # (B, H*W) flat pixel indices ordered by the score above
+    _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
+
+    # importance sampling
+    if num_importance > 0:
+        importance = idx[:, :num_importance]  # B, beta*N
+        # remaining
+        remaining = idx[:, num_importance:]  # B, H*W - beta*N
+
+        # coverage
+        num_coverage = N - num_importance
+
+        if num_coverage <= 0:
+            samples = importance
+        else:
+            coverage_list = []
+            for i in range(B):
+                idx_c = torch.randperm(remaining.size()[1])  # shuffles "H*W - beta*N"
+                coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1))  # 1, N-beta*N
+            coverage = torch.cat(coverage_list, dim=0)  # B, N-beta*N
+            samples = torch.cat((importance, coverage), dim=1)  # B, N
+
+    else:
+        # remaining: no importance samples, so every pixel is a coverage candidate
+        remaining = idx[:, :]  # B, H*W
+        # coverage
+        num_coverage = N
+
+        coverage_list = []
+        for i in range(B):
+            idx_c = torch.randperm(remaining.size()[1])  # shuffles "H*W"
+            coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1))  # 1, N
+        coverage = torch.cat(coverage_list, dim=0)  # B, N
+        samples = coverage
+
+    # point coordinates
+    rows_int = samples // W  # 0 for first row, H-1 for last row
+    rows_float = rows_int / float(H-1)  # 0 to 1.0
+    rows_float = (rows_float * 2.0) - 1.0  # -1.0 to 1.0
+
+    cols_int = samples % W  # 0 for first column, W-1 for last column
+    cols_float = cols_int / float(W-1)  # 0 to 1.0
+    cols_float = (cols_float * 2.0) - 1.0  # -1.0 to 1.0
+
+    point_coords = torch.zeros(B, 1, N, 2, device=device)  # allocated directly on the input's device
+    point_coords[:, 0, :, 0] = cols_float  # x coord
+    point_coords[:, 0, :, 1] = rows_float  # y coord
+    return point_coords, rows_int, cols_int
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/oneformer/LICENSE b/sd-webui-controlnet/annotator/oneformer/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Caroline Chan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/oneformer/__init__.py b/sd-webui-controlnet/annotator/oneformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49eb9ec3490917a13c1a93d63c8b8ad44244926e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/__init__.py
@@ -0,0 +1,45 @@
+import os
+from modules import devices
+from annotator.annotator_path import models_path
+from .api import make_detectron2_model, semantic_run
+
+
+class OneformerDetector:
+    # Checkpoints live in <models_path>/oneformer; `configs` maps a dataset key to its checkpoint file and YAML config.
+    model_dir = os.path.join(models_path, "oneformer")
+    configs = {
+        "coco": {
+            "name": "150_16_swin_l_oneformer_coco_100ep.pth",
+            "config": 'configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml'
+        },
+        "ade20k": {
+            "name": "250_16_swin_l_oneformer_ade20k_160k.pth",
+            "config": 'configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml'
+        }
+    }
+
+    def __init__(self, config):
+        self.model = None
+        self.metadata = None
+        self.config = config
+        self.device = devices.get_device_for("controlnet")
+
+    def load_model(self):
+        weights_path = os.path.join(self.model_dir, self.config["name"])
+        if not os.path.exists(weights_path):  # download the checkpoint on first use
+            from basicsr.utils.download_util import load_file_from_url
+            weights_url = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + self.config["name"]
+            load_file_from_url(weights_url, model_dir=self.model_dir)
+        yaml_path = os.path.join(os.path.dirname(__file__), self.config["config"])
+        self.model, self.metadata = make_detectron2_model(yaml_path, weights_path)
+
+    def unload_model(self):
+        if self.model is not None:  # nothing to move before the first load
+            self.model.model.cpu()
+
+    def __call__(self, img):
+        # Load lazily, move the network to the ControlNet device, then render the semantic map for `img`.
+        if self.model is None:
+            self.load_model()
+        self.model.model.to(self.device)
+        return semantic_run(img, self.model, self.metadata)
diff --git a/sd-webui-controlnet/annotator/oneformer/api.py b/sd-webui-controlnet/annotator/oneformer/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..59e4439f10d537949180b8a9d1b2a0ee347b8ff3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/api.py
@@ -0,0 +1,39 @@
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import torch
+
+from annotator.oneformer.detectron2.config import get_cfg
+from annotator.oneformer.detectron2.projects.deeplab import add_deeplab_config
+from annotator.oneformer.detectron2.data import MetadataCatalog
+
+from annotator.oneformer.oneformer import (
+ add_oneformer_config,
+ add_common_config,
+ add_swin_config,
+ add_dinat_config,
+)
+
+from annotator.oneformer.oneformer.demo.defaults import DefaultPredictor
+from annotator.oneformer.oneformer.demo.visualizer import Visualizer, ColorMode
+
+
+def make_detectron2_model(config_path, ckpt_path):
+    """Build a frozen config from `config_path` with weights `ckpt_path`; return (DefaultPredictor, metadata)."""
+    cfg = get_cfg()
+    # Apply each extension's config hook, in this fixed order, before merging the YAML file.
+    for extend in (add_deeplab_config, add_common_config, add_swin_config, add_oneformer_config, add_dinat_config):
+        extend(cfg)
+    cfg.merge_from_file(config_path)
+    cfg.MODEL.WEIGHTS = ckpt_path
+    cfg.freeze()
+    panoptic_sets = cfg.DATASETS.TEST_PANOPTIC
+    metadata = MetadataCatalog.get(panoptic_sets[0] if len(panoptic_sets) else "__unused")
+    return DefaultPredictor(cfg), metadata
+
+
+def semantic_run(img, predictor, metadata):
+    result = predictor(img[:, :, ::-1], "semantic")  # channels reversed: the OneFormer predictor must be fed BGR
+    canvas = Visualizer(img, is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE)
+    class_map = result["sem_seg"].argmax(dim=0).cpu()  # per-pixel index of the highest-scoring class
+    return canvas.draw_sem_seg(class_map, alpha=1, is_text=False).get_image()
diff --git a/sd-webui-controlnet/annotator/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml b/sd-webui-controlnet/annotator/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..31eab45b878433fc844a13dbdd54f97c936d9b89
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml
@@ -0,0 +1,68 @@
+MODEL:
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "build_resnet_backbone"
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ RESNETS:
+ DEPTH: 50
+ STEM_TYPE: "basic" # not used
+ STEM_OUT_CHANNELS: 64
+ STRIDE_IN_1X1: False
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+ # NORM: "SyncBN"
+ RES5_MULTI_GRID: [1, 1, 1] # not used
+DATASETS:
+ TRAIN: ("ade20k_panoptic_train",)
+ TEST_PANOPTIC: ("ade20k_panoptic_val",)
+ TEST_INSTANCE: ("ade20k_instance_val",)
+ TEST_SEMANTIC: ("ade20k_sem_seg_val",)
+SOLVER:
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.0001
+ MAX_ITER: 160000
+ WARMUP_FACTOR: 1.0
+ WARMUP_ITERS: 0
+ WEIGHT_DECAY: 0.05
+ OPTIMIZER: "ADAMW"
+ LR_SCHEDULER_NAME: "WarmupPolyLR"
+ BACKBONE_MULTIPLIER: 0.1
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 0.01
+ NORM_TYPE: 2.0
+ AMP:
+ ENABLED: True
+INPUT:
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"]
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
+ MIN_SIZE_TEST: 512
+ MAX_SIZE_TRAIN: 2048
+ MAX_SIZE_TEST: 2048
+ CROP:
+ ENABLED: True
+ TYPE: "absolute"
+ SIZE: (512, 512)
+ SINGLE_CATEGORY_MAX_AREA: 1.0
+ COLOR_AUG_SSD: True
+ SIZE_DIVISIBILITY: 512 # used in dataset mapper
+ FORMAT: "RGB"
+ DATASET_MAPPER_NAME: "oneformer_unified"
+ MAX_SEQ_LEN: 77
+ TASK_SEQ_LEN: 77
+ TASK_PROB:
+ SEMANTIC: 0.33
+ INSTANCE: 0.66
+TEST:
+ EVAL_PERIOD: 5000
+ AUG:
+ ENABLED: False
+ MIN_SIZES: [256, 384, 512, 640, 768, 896]
+ MAX_SIZE: 3584
+ FLIP: True
+DATALOADER:
+ FILTER_EMPTY_ANNOTATIONS: True
+ NUM_WORKERS: 4
+VERSION: 2
\ No newline at end of file
diff --git a/sd-webui-controlnet/annotator/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml b/sd-webui-controlnet/annotator/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..770ffc81907f8d7c7520e079b1c46060707254b8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml
@@ -0,0 +1,58 @@
+_BASE_: Base-ADE20K-UnifiedSegmentation.yaml
+MODEL:
+ META_ARCHITECTURE: "OneFormer"
+ SEM_SEG_HEAD:
+ NAME: "OneFormerHead"
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 150
+ LOSS_WEIGHT: 1.0
+ CONVS_DIM: 256
+ MASK_DIM: 256
+ NORM: "GN"
+ # pixel decoder
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ COMMON_STRIDE: 4
+ TRANSFORMER_ENC_LAYERS: 6
+ ONE_FORMER:
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+ DEEP_SUPERVISION: True
+ NO_OBJECT_WEIGHT: 0.1
+ CLASS_WEIGHT: 2.0
+ MASK_WEIGHT: 5.0
+ DICE_WEIGHT: 5.0
+ CONTRASTIVE_WEIGHT: 0.5
+ CONTRASTIVE_TEMPERATURE: 0.07
+ HIDDEN_DIM: 256
+ NUM_OBJECT_QUERIES: 150
+ USE_TASK_NORM: True
+ NHEADS: 8
+ DROPOUT: 0.1
+ DIM_FEEDFORWARD: 2048
+ ENC_LAYERS: 0
+ PRE_NORM: False
+ ENFORCE_INPUT_PROJ: False
+ SIZE_DIVISIBILITY: 32
+ CLASS_DEC_LAYERS: 2
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+ TRAIN_NUM_POINTS: 12544
+ OVERSAMPLE_RATIO: 3.0
+ IMPORTANCE_SAMPLE_RATIO: 0.75
+ TEXT_ENCODER:
+ WIDTH: 256
+ CONTEXT_LENGTH: 77
+ NUM_LAYERS: 6
+ VOCAB_SIZE: 49408
+ PROJ_NUM_LAYERS: 2
+ N_CTX: 16
+ TEST:
+ SEMANTIC_ON: True
+ INSTANCE_ON: True
+ PANOPTIC_ON: True
+ OVERLAP_THRESHOLD: 0.8
+ OBJECT_MASK_THRESHOLD: 0.8
+ TASK: "panoptic"
+TEST:
+ DETECTIONS_PER_IMAGE: 150
diff --git a/sd-webui-controlnet/annotator/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml b/sd-webui-controlnet/annotator/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..69c44ade144e4504077c0fe04fa8bb3491a679ed
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml
@@ -0,0 +1,40 @@
+_BASE_: oneformer_R50_bs16_160k.yaml
+MODEL:
+ BACKBONE:
+ NAME: "D2SwinTransformer"
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [2, 2, 18, 2]
+ NUM_HEADS: [6, 12, 24, 48]
+ WINDOW_SIZE: 12
+ APE: False
+ DROP_PATH_RATE: 0.3
+ PATCH_NORM: True
+ PRETRAIN_IMG_SIZE: 384
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ ONE_FORMER:
+ NUM_OBJECT_QUERIES: 250
+INPUT:
+ MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
+ MIN_SIZE_TRAIN_SAMPLING: "choice"
+ MIN_SIZE_TEST: 640
+ MAX_SIZE_TRAIN: 2560
+ MAX_SIZE_TEST: 2560
+ CROP:
+ ENABLED: True
+ TYPE: "absolute"
+ SIZE: (640, 640)
+ SINGLE_CATEGORY_MAX_AREA: 1.0
+ COLOR_AUG_SSD: True
+ SIZE_DIVISIBILITY: 640 # used in dataset mapper
+ FORMAT: "RGB"
+TEST:
+ DETECTIONS_PER_IMAGE: 250
+ EVAL_PERIOD: 5000
+ AUG:
+ ENABLED: False
+ MIN_SIZES: [320, 480, 640, 800, 960, 1120]
+ MAX_SIZE: 4480
+ FLIP: True
diff --git a/sd-webui-controlnet/annotator/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml b/sd-webui-controlnet/annotator/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ccd24f348f9bc7d60dcdc4b74d887708e57cb8a8
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml
@@ -0,0 +1,54 @@
+MODEL:
+ BACKBONE:
+ FREEZE_AT: 0
+ NAME: "build_resnet_backbone"
+ WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ RESNETS:
+ DEPTH: 50
+ STEM_TYPE: "basic" # not used
+ STEM_OUT_CHANNELS: 64
+ STRIDE_IN_1X1: False
+ OUT_FEATURES: ["res2", "res3", "res4", "res5"]
+ # NORM: "SyncBN"
+ RES5_MULTI_GRID: [1, 1, 1] # not used
+DATASETS:
+ TRAIN: ("coco_2017_train_panoptic_with_sem_seg",)
+ TEST_PANOPTIC: ("coco_2017_val_panoptic_with_sem_seg",) # to evaluate instance and semantic performance as well
+ TEST_INSTANCE: ("coco_2017_val",)
+ TEST_SEMANTIC: ("coco_2017_val_panoptic_with_sem_seg",)
+SOLVER:
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.0001
+ STEPS: (327778, 355092)
+ MAX_ITER: 368750
+ WARMUP_FACTOR: 1.0
+ WARMUP_ITERS: 10
+ WEIGHT_DECAY: 0.05
+ OPTIMIZER: "ADAMW"
+ BACKBONE_MULTIPLIER: 0.1
+ CLIP_GRADIENTS:
+ ENABLED: True
+ CLIP_TYPE: "full_model"
+ CLIP_VALUE: 0.01
+ NORM_TYPE: 2.0
+ AMP:
+ ENABLED: True
+INPUT:
+ IMAGE_SIZE: 1024
+ MIN_SCALE: 0.1
+ MAX_SCALE: 2.0
+ FORMAT: "RGB"
+ DATASET_MAPPER_NAME: "coco_unified_lsj"
+ MAX_SEQ_LEN: 77
+ TASK_SEQ_LEN: 77
+ TASK_PROB:
+ SEMANTIC: 0.33
+ INSTANCE: 0.66
+TEST:
+ EVAL_PERIOD: 5000
+DATALOADER:
+ FILTER_EMPTY_ANNOTATIONS: True
+ NUM_WORKERS: 4
+VERSION: 2
diff --git a/sd-webui-controlnet/annotator/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml b/sd-webui-controlnet/annotator/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f768c8fa8b5e4fc1121e65e050053e0d8870cd73
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml
@@ -0,0 +1,59 @@
+_BASE_: Base-COCO-UnifiedSegmentation.yaml
+MODEL:
+ META_ARCHITECTURE: "OneFormer"
+ SEM_SEG_HEAD:
+ NAME: "OneFormerHead"
+ IGNORE_VALUE: 255
+ NUM_CLASSES: 133
+ LOSS_WEIGHT: 1.0
+ CONVS_DIM: 256
+ MASK_DIM: 256
+ NORM: "GN"
+ # pixel decoder
+ PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder"
+ IN_FEATURES: ["res2", "res3", "res4", "res5"]
+ DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
+ COMMON_STRIDE: 4
+ TRANSFORMER_ENC_LAYERS: 6
+ ONE_FORMER:
+ TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder"
+ TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
+ DEEP_SUPERVISION: True
+ NO_OBJECT_WEIGHT: 0.1
+ CLASS_WEIGHT: 2.0
+ MASK_WEIGHT: 5.0
+ DICE_WEIGHT: 5.0
+ CONTRASTIVE_WEIGHT: 0.5
+ CONTRASTIVE_TEMPERATURE: 0.07
+ HIDDEN_DIM: 256
+ NUM_OBJECT_QUERIES: 150
+ USE_TASK_NORM: True
+ NHEADS: 8
+ DROPOUT: 0.1
+ DIM_FEEDFORWARD: 2048
+ ENC_LAYERS: 0
+ PRE_NORM: False
+ ENFORCE_INPUT_PROJ: False
+ SIZE_DIVISIBILITY: 32
+ CLASS_DEC_LAYERS: 2
+ DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
+ TRAIN_NUM_POINTS: 12544
+ OVERSAMPLE_RATIO: 3.0
+ IMPORTANCE_SAMPLE_RATIO: 0.75
+ TEXT_ENCODER:
+ WIDTH: 256
+ CONTEXT_LENGTH: 77
+ NUM_LAYERS: 6
+ VOCAB_SIZE: 49408
+ PROJ_NUM_LAYERS: 2
+ N_CTX: 16
+ TEST:
+ SEMANTIC_ON: True
+ INSTANCE_ON: True
+ PANOPTIC_ON: True
+ DETECTION_ON: False
+ OVERLAP_THRESHOLD: 0.8
+ OBJECT_MASK_THRESHOLD: 0.8
+ TASK: "panoptic"
+TEST:
+ DETECTIONS_PER_IMAGE: 150
diff --git a/sd-webui-controlnet/annotator/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml b/sd-webui-controlnet/annotator/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..faae655317c52d90b9f756417f8b1a1adcbe78f2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml
@@ -0,0 +1,25 @@
+_BASE_: oneformer_R50_bs16_50ep.yaml
+MODEL:
+ BACKBONE:
+ NAME: "D2SwinTransformer"
+ SWIN:
+ EMBED_DIM: 192
+ DEPTHS: [2, 2, 18, 2]
+ NUM_HEADS: [6, 12, 24, 48]
+ WINDOW_SIZE: 12
+ APE: False
+ DROP_PATH_RATE: 0.3
+ PATCH_NORM: True
+ PRETRAIN_IMG_SIZE: 384
+ WEIGHTS: "swin_large_patch4_window12_384_22k.pkl"
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
+ PIXEL_STD: [58.395, 57.120, 57.375]
+ ONE_FORMER:
+ NUM_OBJECT_QUERIES: 150
+SOLVER:
+ STEPS: (655556, 735184)
+ MAX_ITER: 737500
+ AMP:
+ ENABLED: False
+TEST:
+ DETECTIONS_PER_IMAGE: 150
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdd994b49294485c27610772f97f177741f5518f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from .utils.env import setup_environment
+
+setup_environment()
+
+
+# This line will be programmatically read/written by setup.py.
+# Leave it at the bottom of this file and don't touch it.
+__version__ = "0.6"
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..99da0469ae7e169d8970e4b642fed3f870076860
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/__init__.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+# File:
+
+
+from . import catalog as _UNUSED # register the handler
+from .detection_checkpoint import DetectionCheckpointer
+from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
+
+__all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/c2_model_loading.py b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/c2_model_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6de2a3c830089aa7a0d27df96bb4a45fc5a7b0d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/c2_model_loading.py
@@ -0,0 +1,412 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import re
+from typing import Dict, List
+import torch
+from tabulate import tabulate
+
+
+def convert_basic_c2_names(original_keys):
+    """
+    Apply some basic name conversion to names in C2 weights.
+    It only deals with typical backbone models. The input list is not modified.
+
+    Args:
+        original_keys (list[str]): weight names as found in the C2 checkpoint.
+    Returns:
+        list[str]: The same number of strings matching those in original_keys.
+    """
+    layer_keys = copy.deepcopy(original_keys)
+    layer_keys = [
+        {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
+    ]  # some hard-coded mappings
+
+    layer_keys = [k.replace("_", ".") for k in layer_keys]
+    layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
+    layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
+    # Uniform both bn and gn names to "norm"
+    layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
+    layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
+    layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
+    layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
+
+    # stem
+    layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
+    # to avoid mis-matching with "conv1" in other components (e.g. detection head)
+    layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
+
+    # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
+    # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
+    # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
+    # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
+    # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
+
+    # blocks
+    layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
+    layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
+
+    # DensePose substitutions
+    layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
+    layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
+    layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
+    layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
+    layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
+    return layer_keys
+
+
+def convert_c2_detectron_names(weights):
+    """
+    Map Caffe2 Detectron weight names to Detectron2 names.
+
+    Args:
+        weights (dict): name -> tensor
+
+    Returns:
+        dict: detectron2 names -> tensor
+        dict: detectron2 names -> C2 names
+    """
+    logger = logging.getLogger(__name__)
+    logger.info("Renaming Caffe2 weights ......")
+    original_keys = sorted(weights.keys())
+    layer_keys = copy.deepcopy(original_keys)
+
+    layer_keys = convert_basic_c2_names(layer_keys)
+
+    # --------------------------------------------------------------------------
+    # RPN hidden representation conv
+    # --------------------------------------------------------------------------
+    # FPN case
+    # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
+    # shared for all other levels, hence the appearance of "fpn2"
+    layer_keys = [
+        k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
+    ]
+    # Non-FPN case
+    layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # RPN box transformation conv
+    # --------------------------------------------------------------------------
+    # FPN case (see note above about "fpn2")
+    layer_keys = [
+        k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
+        for k in layer_keys
+    ]
+    layer_keys = [
+        k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
+        for k in layer_keys
+    ]
+    # Non-FPN case
+    layer_keys = [
+        k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
+    ]
+    layer_keys = [
+        k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
+        for k in layer_keys
+    ]
+
+    # --------------------------------------------------------------------------
+    # Fast R-CNN box head
+    # --------------------------------------------------------------------------
+    layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
+    layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
+    layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
+    layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
+    # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
+    layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # FPN lateral and output convolutions
+    # --------------------------------------------------------------------------
+    def fpn_map(name):
+        """
+        Look for keys with the following patterns:
+        1) Starts with "fpn.inner."
+           Example: "fpn.inner.res2.2.sum.lateral.weight"
+           Meaning: These are lateral pathway convolutions
+        2) Starts with "fpn.res"
+           Example: "fpn.res2.2.sum.weight"
+           Meaning: These are FPN output convolutions
+        """
+        splits = name.split(".")
+        norm = ".norm" if "norm" in splits else ""
+        if name.startswith("fpn.inner."):
+            # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
+            stage = int(splits[2][len("res") :])
+            return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
+        elif name.startswith("fpn.res"):
+            # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
+            stage = int(splits[1][len("res") :])
+            return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
+        return name
+
+    layer_keys = [fpn_map(k) for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # Mask R-CNN mask head
+    # --------------------------------------------------------------------------
+    # roi_heads.StandardROIHeads case
+    layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
+    layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
+    layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
+    # roi_heads.Res5ROIHeads case
+    layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # Keypoint R-CNN head
+    # --------------------------------------------------------------------------
+    # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
+    layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
+    layer_keys = [
+        k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
+    ]
+    layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
+
+    # --------------------------------------------------------------------------
+    # Done with replacements
+    # --------------------------------------------------------------------------
+    assert len(set(layer_keys)) == len(layer_keys)
+    assert len(original_keys) == len(layer_keys)
+
+    new_weights = {}
+    new_keys_to_original_keys = {}
+    for orig, renamed in zip(original_keys, layer_keys):
+        new_keys_to_original_keys[renamed] = orig
+        if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
+            # remove the meaningless prediction weight for background class
+            new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
+            new_weights[renamed] = weights[orig][new_start_idx:]
+            logger.info(
+                "Remove prediction weight for background class in {}. The shape changes from "
+                "{} to {}.".format(
+                    renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
+                )
+            )
+        elif renamed.startswith("cls_score."):
+            # move weights of bg class from original index 0 to last index
+            logger.info(
+                "Move classification weights for background class in {} from index 0 to "
+                "index {}.".format(renamed, weights[orig].shape[0] - 1)
+            )
+            new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
+        else:
+            new_weights[renamed] = weights[orig]
+
+    return new_weights, new_keys_to_original_keys
+
+
+# Note the current matching is not symmetric.
+# it assumes model_state_dict will have longer names.
+def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
+    """
+    Match names between the two state-dicts, and return a new ckpt_state_dict with names
+    converted to match model_state_dict with heuristics. The returned dict can be later
+    loaded with fvcore checkpointer.
+    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
+    model and will be renamed at first.
+
+    Strategy: suppose that the models that we will create will have prefixes appended
+    to each of its keys, for example due to an extra level of nesting that the original
+    pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+    might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+    res2.conv1.weight. We thus want to match both parameters together.
+    For that, we look for each model weight, look among all loaded keys if there is one
+    that is a suffix of the current weight name, and use it if that's the case.
+    If multiple matches exist, take the one with longest size
+    of the corresponding name. For example, for the same model as before, the pretrained
+    weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+    we want to match backbone[0].body.conv1.weight to conv1.weight, and
+    backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+    """
+    model_keys = sorted(model_state_dict.keys())
+    if c2_conversion:
+        ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
+        # original_keys: the name in the original dict (before renaming)
+    else:
+        original_keys = {x: x for x in ckpt_state_dict.keys()}
+    ckpt_keys = sorted(ckpt_state_dict.keys())
+
+    def match(a, b):
+        # Matched ckpt_key should be a complete (starts with '.') suffix.
+        # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
+        # but matches whatever_conv1 or mesh_head.whatever_conv1.
+        return a == b or a.endswith("." + b)
+
+    # get a matrix of string matches, where each (i, j) entry correspond to the size of the
+    # ckpt_key string, if it matches
+    match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
+    match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
+    # use the matched one with longest size in case of multiple matches
+    max_match_size, idxs = match_matrix.max(1)
+    # remove indices that correspond to no-match
+    idxs[max_match_size == 0] = -1
+
+    logger = logging.getLogger(__name__)
+    # matched_pairs (matched checkpoint key --> matched model key)
+    matched_keys = {}
+    result_state_dict = {}
+    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
+        if idx_ckpt == -1:
+            continue
+        key_model = model_keys[idx_model]
+        key_ckpt = ckpt_keys[idx_ckpt]
+        value_ckpt = ckpt_state_dict[key_ckpt]
+        shape_in_model = model_state_dict[key_model].shape
+
+        if shape_in_model != value_ckpt.shape:
+            logger.warning(
+                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
+                    key_ckpt, value_ckpt.shape, key_model, shape_in_model
+                )
+            )
+            logger.warning(
+                "{} will not be loaded. Please double check and see if this is desired.".format(
+                    key_ckpt
+                )
+            )
+            continue
+
+        assert key_model not in result_state_dict
+        result_state_dict[key_model] = value_ckpt
+        if key_ckpt in matched_keys:  # already added to matched_keys
+            logger.error(
+                "Ambiguity found for {} in checkpoint! "
+                "It matches at least two keys in the model ({} and {}).".format(
+                    key_ckpt, key_model, matched_keys[key_ckpt]
+                )
+            )
+            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
+
+        matched_keys[key_ckpt] = key_model
+
+    # logging:
+    matched_model_keys = sorted(matched_keys.values())
+    if len(matched_model_keys) == 0:
+        logger.warning("No weights in checkpoint matched with model.")
+        return ckpt_state_dict
+    common_prefix = _longest_common_prefix(matched_model_keys)
+    rev_matched_keys = {v: k for k, v in matched_keys.items()}
+    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
+
+    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
+    table = []
+    memo = set()
+    for key_model in matched_model_keys:
+        if key_model in memo:
+            continue
+        if key_model in model_key_groups:
+            group = model_key_groups[key_model]
+            memo |= set(group)
+            shapes = [tuple(model_state_dict[k].shape) for k in group]
+            table.append(
+                (
+                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
+                    _group_str([original_keys[k] for k in group]),
+                    " ".join([str(x).replace(" ", "") for x in shapes]),
+                )
+            )
+        else:
+            key_checkpoint = original_keys[key_model]
+            shape = str(tuple(model_state_dict[key_model].shape))
+            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
+    table_str = tabulate(
+        table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
+    )
+    logger.info(
+        "Following weights matched with "
+        + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
+        + ":\n"
+        + table_str
+    )
+
+    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_keys]  # O(1) dict lookup
+    for k in unmatched_ckpt_keys:
+        result_state_dict[k] = ckpt_state_dict[k]
+    return result_state_dict
+
+
+def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
+    """
+    Params in the same submodule are grouped together.
+
+    Args:
+        keys: names of all parameters
+        original_names: mapping from parameter name to their name in the checkpoint
+
+    Returns:
+        dict[name -> all names in the same group, including the name itself]
+    """
+
+    def _submodule_name(key):
+        pos = key.rfind(".")
+        if pos < 0:
+            return None
+        prefix = key[: pos + 1]
+        return prefix
+
+    all_submodules = [_submodule_name(k) for k in keys]
+    all_submodules = [x for x in all_submodules if x]
+    all_submodules = sorted(all_submodules, key=len)
+
+    ret = {}
+    for prefix in all_submodules:
+        group = [k for k in keys if k.startswith(prefix)]
+        if len(group) <= 1:
+            continue
+        original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
+        if len(original_name_lcp) == 0:
+            # don't group weights if original names don't share prefix
+            continue
+
+        for k in group:
+            if k in ret:
+                continue
+            ret[k] = group
+    return ret
+
+
+def _longest_common_prefix(names: List[str]) -> str:
+    """
+    ["abc.zfg", "abc.zef"] -> "abc."  (only the leading run of equal components counts)
+    """
+    names = [n.split(".") for n in names]
+    m1, m2 = min(names), max(names)
+    depth = next((i for i, (a, b) in enumerate(zip(m1, m2)) if a != b), min(len(m1), len(m2)))
+    ret = ".".join(m1[:depth]) + "." if depth else ""
+    return ret
+
+
+def _longest_common_prefix_str(names: List[str]) -> str:
+    """Return the longest character-level prefix shared by every string in `names`."""
+    # The lexicographic extremes bound the common prefix of the whole list.
+    lo, hi = min(names), max(names)
+    size = 0
+    for a, b in zip(lo, hi):
+        if a != b:
+            break
+        size += 1
+    return lo[:size]
+
+
+def _group_str(names: List[str]) -> str:
+    """
+    Turn "common1", "common2", "common3" into "common{1,2,3}"
+    """
+    lcp = _longest_common_prefix_str(names)
+    rest = [x[len(lcp) :] for x in names]
+    rest = "{" + ",".join(rest) + "}"
+    ret = lcp + rest
+
+    # add some simplification for BN specifically
+    ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
+    ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
+    return ret
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/catalog.py b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/catalog.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5641858fea4936ad10b07a4237faba78dda77ff
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/catalog.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+
+from annotator.oneformer.detectron2.utils.file_io import PathHandler, PathManager
+
+
+class ModelCatalog(object):
+    """
+    Store mappings from names to third-party models.
+    """
+
+    S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron"
+
+    # MSRA models have STRIDE_IN_1X1=True. False otherwise.
+    # NOTE: all BN models here have fused BN into an affine layer.
+    # As a result, you should only load them to a model with "FrozenBN".
+    # Loading them to a model with regular BN or SyncBN is wrong.
+    # Even when loaded to FrozenBN, it is still different from affine by an epsilon,
+    # which should be negligible for training.
+    # NOTE: all models here uses PIXEL_STD=[1,1,1]
+    # NOTE: Most of the BN models here are no longer used. We use the
+    # re-converted pre-trained models under detectron2 model zoo instead.
+    C2_IMAGENET_MODELS = {
+        "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl",
+        "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl",
+        "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl",
+        "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl",
+        "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl",
+        "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl",
+        "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl",
+    }
+
+    C2_DETECTRON_PATH_FORMAT = (
+        "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl"  # noqa B950
+    )
+
+    C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival"
+    C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival"
+
+    # format: {model_name} -> part of the url
+    C2_DETECTRON_MODELS = {
+        "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW",  # noqa B950
+        "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I",  # noqa B950
+        "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7",  # noqa B950
+        "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ",  # noqa B950
+        "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB",  # noqa B950
+        "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC",  # noqa B950
+        "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT",  # noqa B950
+        "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI",  # noqa B950
+        "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q",  # noqa B950
+        "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao",  # noqa B950
+        "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L",  # noqa B950
+        "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179",  # noqa B950
+        "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2",  # noqa B950
+    }
+
+    @staticmethod
+    def get(name):
+        if name.startswith("Caffe2Detectron/COCO"):
+            return ModelCatalog._get_c2_detectron_baseline(name)
+        if name.startswith("ImageNetPretrained/"):
+            return ModelCatalog._get_c2_imagenet_pretrained(name)
+        raise RuntimeError("model not present in the catalog: {}".format(name))
+
+    @staticmethod
+    def _get_c2_imagenet_pretrained(name):
+        prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX
+        name = name[len("ImageNetPretrained/") :]
+        name = ModelCatalog.C2_IMAGENET_MODELS[name]
+        url = "/".join([prefix, name])
+        return url
+
+    @staticmethod
+    def _get_c2_detectron_baseline(name):
+        name = name[len("Caffe2Detectron/COCO/") :]
+        url = ModelCatalog.C2_DETECTRON_MODELS[name]
+        if "keypoint_rcnn" in name:
+            dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS
+        else:
+            dataset = ModelCatalog.C2_DATASET_COCO
+
+        if "35998355/rpn_R-50-C4_1x" in name:
+            # this one model is somehow different from others ..
+            model_type = "rpn"
+        else:
+            model_type = "generalized_rcnn"
+
+        # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`.
+        url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format(
+            prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=model_type, dataset=dataset
+        )
+        return url
+
+
+class ModelCatalogHandler(PathHandler):
+    """
+    Resolve URL like catalog://.
+    """
+
+    PREFIX = "catalog://"
+
+    def _get_supported_prefixes(self):
+        return [self.PREFIX]
+
+    def _get_local_path(self, path, **kwargs):
+        logger = logging.getLogger(__name__)
+        catalog_path = ModelCatalog.get(path[len(self.PREFIX) :])
+        logger.info("Catalog entry {} points to {}".format(path, catalog_path))
+        return PathManager.get_local_path(catalog_path, **kwargs)
+
+    def _open(self, path, mode="r", **kwargs):
+        return PathManager.open(self._get_local_path(path), mode, **kwargs)
+
+
+PathManager.register_handler(ModelCatalogHandler())
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/detection_checkpoint.py b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/detection_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d411e54bd5e004504423ba052db6f85ec511f72
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/checkpoint/detection_checkpoint.py
@@ -0,0 +1,145 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import os
+import pickle
+from urllib.parse import parse_qs, urlparse
+import torch
+from fvcore.common.checkpoint import Checkpointer
+from torch.nn.parallel import DistributedDataParallel
+
+import annotator.oneformer.detectron2.utils.comm as comm
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .c2_model_loading import align_and_update_state_dicts
+
+
+class DetectionCheckpointer(Checkpointer):
+    """
+    Same as :class:`Checkpointer`, but is able to:
+    1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models.
+    2. correctly load checkpoints that are only available on the master worker
+    """
+
+    def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables):
+        is_main_process = comm.is_main_process()
+        super().__init__(
+            model,
+            save_dir,
+            save_to_disk=is_main_process if save_to_disk is None else save_to_disk,
+            **checkpointables,
+        )
+        self.path_manager = PathManager
+        self._parsed_url_during_load = None
+
+    def load(self, path, *args, **kwargs):
+        assert self._parsed_url_during_load is None
+        need_sync = False
+        logger = logging.getLogger(__name__)
+        logger.info("[DetectionCheckpointer] Loading from {} ...".format(path))
+
+        if path and isinstance(self.model, DistributedDataParallel):
+            path = self.path_manager.get_local_path(path)
+            has_file = os.path.isfile(path)
+            all_has_file = comm.all_gather(has_file)
+            if not all_has_file[0]:
+                raise OSError(f"File {path} not found on main worker.")
+            if not all(all_has_file):
+                logger.warning(
+                    f"Not all workers can read checkpoint {path}. "
+                    "Training may fail to fully resume."
+                )
+                # TODO: broadcast the checkpoint file contents from main
+                # worker, and load from it instead.
+                need_sync = True
+            if not has_file:
+                path = None  # don't load if not readable
+
+        if path:
+            parsed_url = urlparse(path)
+            self._parsed_url_during_load = parsed_url
+            path = parsed_url._replace(query="").geturl()  # remove query from filename
+            path = self.path_manager.get_local_path(path)
+
+        self.logger.setLevel('CRITICAL')  # NOTE(review): level is never restored after load
+        ret = super().load(path, *args, **kwargs)
+
+        if need_sync:
+            logger.info("Broadcasting model states from main worker ...")
+            self.model._sync_params_and_buffers()
+        self._parsed_url_during_load = None  # reset to None
+        return ret
+
+    def _load_file(self, filename):
+        if filename.endswith(".pkl"):
+            with PathManager.open(filename, "rb") as f:
+                data = pickle.load(f, encoding="latin1")  # NOTE: unpickling can run code; trusted files only
+            if "model" in data and "__author__" in data:
+                # file is in Detectron2 model zoo format
+                self.logger.info("Reading a file from '{}'".format(data["__author__"]))
+                return data
+            else:
+                # assume file is from Caffe2 / Detectron1 model zoo
+                if "blobs" in data:
+                    # Detection models have "blobs", but ImageNet models don't
+                    data = data["blobs"]
+                data = {k: v for k, v in data.items() if not k.endswith("_momentum")}
+                return {"model": data, "__author__": "Caffe2", "matching_heuristics": True}
+        elif filename.endswith(".pyth"):
+            # assume file is from pycls; no one else seems to use the ".pyth" extension
+            with PathManager.open(filename, "rb") as f:
+                data = torch.load(f)  # NOTE: also pickle-based; trusted files only
+            assert (
+                "model_state" in data
+            ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'."
+            model_state = {
+                k: v
+                for k, v in data["model_state"].items()
+                if not k.endswith("num_batches_tracked")
+            }
+            return {"model": model_state, "__author__": "pycls", "matching_heuristics": True}
+
+        loaded = self._torch_load(filename)
+        if "model" not in loaded:
+            loaded = {"model": loaded}
+        assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`"
+        parsed_url = self._parsed_url_during_load
+        queries = parse_qs(parsed_url.query)
+        if queries.pop("matching_heuristics", "False") == ["True"]:
+            loaded["matching_heuristics"] = True
+        if len(queries) > 0:
+            raise ValueError(
+                f"Unsupported query remaining: {queries}, original filename: {parsed_url.geturl()}"
+            )
+        return loaded
+
+    def _torch_load(self, f):
+        return super()._load_file(f)
+
+    def _load_model(self, checkpoint):
+        if checkpoint.get("matching_heuristics", False):
+            self._convert_ndarray_to_tensor(checkpoint["model"])
+            # convert weights by name-matching heuristics
+            checkpoint["model"] = align_and_update_state_dicts(
+                self.model.state_dict(),
+                checkpoint["model"],
+                c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
+            )
+        # for non-caffe2 models, use standard ways to load it
+        incompatible = super()._load_model(checkpoint)
+
+        model_buffers = dict(self.model.named_buffers(recurse=False))
+        for k in ["pixel_mean", "pixel_std"]:
+            # Ignore missing key message about pixel_mean/std.
+            # Though they may be missing in old checkpoints, they will be correctly
+            # initialized from config anyway.
+            if k in model_buffers:
+                try:
+                    incompatible.missing_keys.remove(k)
+                except ValueError:
+                    pass
+        for k in incompatible.unexpected_keys[:]:
+            # Ignore unexpected keys about cell anchors. They exist in old checkpoints
+            # but now they are non-persistent buffers and will not be in new checkpoints.
+            if "anchor_generator.cell_anchors" in k:
+                incompatible.unexpected_keys.remove(k)
+        return incompatible
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/config/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a78ed118685fcfd869f7a72caf6b94621530196a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/config/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .compat import downgrade_config, upgrade_config
+from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
+from .instantiate import instantiate
+from .lazy import LazyCall, LazyConfig
+
+__all__ = [
+ "CfgNode",
+ "get_cfg",
+ "global_cfg",
+ "set_global_cfg",
+ "downgrade_config",
+ "upgrade_config",
+ "configurable",
+ "instantiate",
+ "LazyCall",
+ "LazyConfig",
+]
+
+
+from annotator.oneformer.detectron2.utils.env import fixup_module_metadata
+
+fixup_module_metadata(__name__, globals(), __all__)
+del fixup_module_metadata
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/config/compat.py b/sd-webui-controlnet/annotator/oneformer/detectron2/config/compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..11a08c439bf14defd880e37a938fab8a08e68eeb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/config/compat.py
@@ -0,0 +1,229 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+Backward compatibility of configs.
+
+Instructions to bump version:
++ It's not needed to bump version if new keys are added.
+ It's only needed when backward-incompatible changes happen
+ (i.e., some existing keys disappear, or the meaning of a key changes)
++ To bump version, do the following:
+ 1. Increment _C.VERSION in defaults.py
+ 2. Add a converter in this file.
+
+ Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X,
+ and a function "downgrade" which in-place downgrades config from X to X-1
+
+ In each function, VERSION is left unchanged.
+
+ Each converter assumes that its input has the relevant keys
+ (i.e., the input is not a partial config).
+ 3. Run the tests (test_config.py) to make sure the upgrade & downgrade
+ functions are consistent.
+"""
+
+import logging
+from typing import List, Optional, Tuple
+
+from .config import CfgNode as CN
+from .defaults import _C
+
+__all__ = ["upgrade_config", "downgrade_config"]
+
+
+def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN:
+ """
+ Upgrade a config from its current version to a newer version.
+
+ Args:
+ cfg (CfgNode):
+ to_version (int): defaults to the latest version.
+ """
+ cfg = cfg.clone()
+ if to_version is None:
+ to_version = _C.VERSION
+
+ assert cfg.VERSION <= to_version, "Cannot upgrade from v{} to v{}!".format(
+ cfg.VERSION, to_version
+ )
+ for k in range(cfg.VERSION, to_version):
+ converter = globals()["ConverterV" + str(k + 1)]
+ converter.upgrade(cfg)
+ cfg.VERSION = k + 1
+ return cfg
+
+
+def downgrade_config(cfg: CN, to_version: int) -> CN:
+ """
+ Downgrade a config from its current version to an older version.
+
+ Args:
+ cfg (CfgNode):
+ to_version (int):
+
+ Note:
+ A general downgrade of arbitrary configs is not always possible due to the
+ different functionalities in different versions.
+ The purpose of downgrade is only to recover the defaults in old versions,
+ allowing it to load an old partial yaml config.
+ Therefore, the implementation only needs to fill in the default values
+ in the old version when a general downgrade is not possible.
+ """
+ cfg = cfg.clone()
+ assert cfg.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format(
+ cfg.VERSION, to_version
+ )
+ for k in range(cfg.VERSION, to_version, -1):
+ converter = globals()["ConverterV" + str(k)]
+ converter.downgrade(cfg)
+ cfg.VERSION = k - 1
+ return cfg
+
+
+def guess_version(cfg: CN, filename: str) -> int:
+ """
+ Guess the version of a partial config where the VERSION field is not specified.
+ Returns the version, or the latest if cannot make a guess.
+
+ This makes it easier for users to migrate.
+ """
+ logger = logging.getLogger(__name__)
+
+ def _has(name: str) -> bool:
+ cur = cfg
+ for n in name.split("."):
+ if n not in cur:
+ return False
+ cur = cur[n]
+ return True
+
+ # Most users' partial configs have "MODEL.WEIGHT", so guess on it
+ ret = None
+ if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"):
+ ret = 1
+
+ if ret is not None:
+ logger.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, ret))
+ else:
+ ret = _C.VERSION
+ logger.warning(
+ "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format(
+ filename, ret
+ )
+ )
+ return ret
+
+
+def _rename(cfg: CN, old: str, new: str) -> None:
+ old_keys = old.split(".")
+ new_keys = new.split(".")
+
+ def _set(key_seq: List[str], val: str) -> None:
+ cur = cfg
+ for k in key_seq[:-1]:
+ if k not in cur:
+ cur[k] = CN()
+ cur = cur[k]
+ cur[key_seq[-1]] = val
+
+ def _get(key_seq: List[str]) -> CN:
+ cur = cfg
+ for k in key_seq:
+ cur = cur[k]
+ return cur
+
+ def _del(key_seq: List[str]) -> None:
+ cur = cfg
+ for k in key_seq[:-1]:
+ cur = cur[k]
+ del cur[key_seq[-1]]
+ if len(cur) == 0 and len(key_seq) > 1:
+ _del(key_seq[:-1])
+
+ _set(new_keys, _get(old_keys))
+ _del(old_keys)
+
+
+class _RenameConverter:
+ """
+ A converter that handles simple rename.
+ """
+
+ RENAME: List[Tuple[str, str]] = [] # list of tuples of (old name, new name)
+
+ @classmethod
+ def upgrade(cls, cfg: CN) -> None:
+ for old, new in cls.RENAME:
+ _rename(cfg, old, new)
+
+ @classmethod
+ def downgrade(cls, cfg: CN) -> None:
+ for old, new in cls.RENAME[::-1]:
+ _rename(cfg, new, old)
+
+
+class ConverterV1(_RenameConverter):
+    RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")]  # the only v0 -> v1 change
+
+
+class ConverterV2(_RenameConverter):
+    """
+    Converter between v1 and v2: a large bulk of renames made before the public release.
+    """
+    # (old name, new name) pairs; the inherited upgrade() applies them in this order.
+    RENAME = [
+        ("MODEL.WEIGHT", "MODEL.WEIGHTS"),
+        ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"),
+        ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"),
+        ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"),
+        ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"),
+        (
+            "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD",
+            "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH",
+        ),
+        (
+            "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT",
+            "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT",
+        ),
+        (
+            "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD",
+            "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH",
+        ),
+        ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"),
+        ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"),
+        ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"),
+        ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"),
+        ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"),
+        ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"),
+        ("TEST.AUG_ON", "TEST.AUG.ENABLED"),
+        ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"),
+        ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"),
+        ("TEST.AUG_FLIP", "TEST.AUG.FLIP"),
+    ]
+
+    @classmethod
+    def upgrade(cls, cfg: CN) -> None:
+        super().upgrade(cfg)
+        # Move the anchor settings to MODEL.ANCHOR_GENERATOR and delete the unused copy.
+        if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
+            _rename(
+                cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS"
+            )
+            _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
+            del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"]
+            del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"]
+        else:
+            _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS")
+            _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
+            del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"]
+            del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"]
+        del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"]
+
+    @classmethod
+    def downgrade(cls, cfg: CN) -> None:
+        super().downgrade(cfg)
+        # Restore the RPN anchor keys first, then mirror them into RETINANET.
+        _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS")
+        _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES")
+        cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS
+        cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES
+        cfg.MODEL.RETINANET.ANCHOR_STRIDES = []  # this is not used anywhere in any version
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/config/config.py b/sd-webui-controlnet/annotator/oneformer/detectron2/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b1303422481dc7adb3ee5221377770e0c01a81
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/config/config.py
@@ -0,0 +1,265 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import functools
+import inspect
+import logging
+from fvcore.common.config import CfgNode as _CfgNode
+
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+
+class CfgNode(_CfgNode):
+ """
+ The same as `fvcore.common.config.CfgNode`, but different in:
+
+ 1. Use unsafe yaml loading by default.
+ Note that this may lead to arbitrary code execution: you must not
+ load a config file from untrusted sources before manually inspecting
+ the content of the file.
+ 2. Support config versioning.
+ When attempting to merge an old config, it will convert the old config automatically.
+
+ .. automethod:: clone
+ .. automethod:: freeze
+ .. automethod:: defrost
+ .. automethod:: is_frozen
+ .. automethod:: load_yaml_with_base
+ .. automethod:: merge_from_list
+ .. automethod:: merge_from_other_cfg
+ """
+
+ @classmethod
+ def _open_cfg(cls, filename):
+ return PathManager.open(filename, "r")
+
+ # Note that the default value of allow_unsafe is changed to True
+ def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
+ """
+ Load content from the given config file and merge it into self.
+
+ Args:
+ cfg_filename: config filename
+ allow_unsafe: allow unsafe yaml syntax
+ """
+ assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
+ loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
+ loaded_cfg = type(self)(loaded_cfg)
+
+ # defaults.py needs to import CfgNode
+ from .defaults import _C
+
+ latest_ver = _C.VERSION
+ assert (
+ latest_ver == self.VERSION
+ ), "CfgNode.merge_from_file is only allowed on a config object of latest version!"
+
+ logger = logging.getLogger(__name__)
+
+ loaded_ver = loaded_cfg.get("VERSION", None)
+ if loaded_ver is None:
+ from .compat import guess_version
+
+ loaded_ver = guess_version(loaded_cfg, cfg_filename)
+ assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format(
+ loaded_ver, self.VERSION
+ )
+
+ if loaded_ver == self.VERSION:
+ self.merge_from_other_cfg(loaded_cfg)
+ else:
+ # compat.py needs to import CfgNode
+ from .compat import upgrade_config, downgrade_config
+
+ logger.warning(
+ "Loading an old v{} config file '{}' by automatically upgrading to v{}. "
+ "See docs/CHANGELOG.md for instructions to update your files.".format(
+ loaded_ver, cfg_filename, self.VERSION
+ )
+ )
+ # To convert, first obtain a full config at an old version
+ old_self = downgrade_config(self, to_version=loaded_ver)
+ old_self.merge_from_other_cfg(loaded_cfg)
+ new_config = upgrade_config(old_self)
+ self.clear()
+ self.update(new_config)
+
+ def dump(self, *args, **kwargs):
+ """
+ Returns:
+ str: a yaml string representation of the config
+ """
+ # to make it show up in docs
+ return super().dump(*args, **kwargs)
+
+
+global_cfg = CfgNode()  # module-level config object; set_global_cfg() refills it in place
+
+
+def get_cfg() -> CfgNode:
+ """
+ Get a copy of the default config.
+
+ Returns:
+ a detectron2 CfgNode instance.
+ """
+ from .defaults import _C
+
+ return _C.clone()
+
+
+def set_global_cfg(cfg: CfgNode) -> None:
+ """
+ Let the global config point to the given cfg.
+
+ Assume that the given "cfg" has the key "KEY", after calling
+ `set_global_cfg(cfg)`, the key can be accessed by:
+ ::
+ from annotator.oneformer.detectron2.config import global_cfg
+ print(global_cfg.KEY)
+
+ By using a hacky global config, you can access these configs anywhere,
+ without having to pass the config object or the values deep into the code.
+ This is a hacky feature introduced for quick prototyping / research exploration.
+ """
+ global global_cfg
+ global_cfg.clear()
+ global_cfg.update(cfg)
+
+
+def configurable(init_func=None, *, from_config=None):
+ """
+ Decorate a function or a class's __init__ method so that it can be called
+ with a :class:`CfgNode` object using a :func:`from_config` function that translates
+ :class:`CfgNode` to arguments.
+
+ Examples:
+ ::
+ # Usage 1: Decorator on __init__:
+ class A:
+ @configurable
+ def __init__(self, a, b=2, c=3):
+ pass
+
+ @classmethod
+ def from_config(cls, cfg): # 'cfg' must be the first argument
+ # Returns kwargs to be passed to __init__
+ return {"a": cfg.A, "b": cfg.B}
+
+ a1 = A(a=1, b=2) # regular construction
+ a2 = A(cfg) # construct with a cfg
+ a3 = A(cfg, b=3, c=4) # construct with extra overwrite
+
+ # Usage 2: Decorator on any function. Needs an extra from_config argument:
+ @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B})
+ def a_func(a, b=2, c=3):
+ pass
+
+ a1 = a_func(a=1, b=2) # regular call
+ a2 = a_func(cfg) # call with a cfg
+ a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
+
+ Args:
+ init_func (callable): a class's ``__init__`` method in usage 1. The
+ class must have a ``from_config`` classmethod which takes `cfg` as
+ the first argument.
+ from_config (callable): the from_config function in usage 2. It must take `cfg`
+ as its first argument.
+ """
+
+ if init_func is not None:
+ assert (
+ inspect.isfunction(init_func)
+ and from_config is None
+ and init_func.__name__ == "__init__"
+ ), "Incorrect use of @configurable. Check API documentation for examples."
+
+ @functools.wraps(init_func)
+ def wrapped(self, *args, **kwargs):
+ try:
+ from_config_func = type(self).from_config
+ except AttributeError as e:
+ raise AttributeError(
+ "Class with @configurable must have a 'from_config' classmethod."
+ ) from e
+ if not inspect.ismethod(from_config_func):
+ raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
+
+ if _called_with_cfg(*args, **kwargs):
+ explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
+ init_func(self, **explicit_args)
+ else:
+ init_func(self, *args, **kwargs)
+
+ return wrapped
+
+ else:
+ if from_config is None:
+ return configurable # @configurable() is made equivalent to @configurable
+ assert inspect.isfunction(
+ from_config
+ ), "from_config argument of configurable must be a function!"
+
+ def wrapper(orig_func):
+ @functools.wraps(orig_func)
+ def wrapped(*args, **kwargs):
+ if _called_with_cfg(*args, **kwargs):
+ explicit_args = _get_args_from_config(from_config, *args, **kwargs)
+ return orig_func(**explicit_args)
+ else:
+ return orig_func(*args, **kwargs)
+
+ wrapped.from_config = from_config
+ return wrapped
+
+ return wrapper
+
+
+def _get_args_from_config(from_config_func, *args, **kwargs):
+ """
+ Use `from_config` to obtain explicit arguments.
+
+ Returns:
+ dict: arguments to be used for cls.__init__
+ """
+ signature = inspect.signature(from_config_func)
+ if list(signature.parameters.keys())[0] != "cfg":
+ if inspect.isfunction(from_config_func):
+ name = from_config_func.__name__
+ else:
+ name = f"{from_config_func.__self__}.from_config"
+ raise TypeError(f"{name} must take 'cfg' as the first argument!")
+ support_var_arg = any(
+ param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
+ for param in signature.parameters.values()
+ )
+ if support_var_arg: # forward all arguments to from_config, if from_config accepts them
+ ret = from_config_func(*args, **kwargs)
+ else:
+ # forward supported arguments to from_config
+ supported_arg_names = set(signature.parameters.keys())
+ extra_kwargs = {}
+ for name in list(kwargs.keys()):
+ if name not in supported_arg_names:
+ extra_kwargs[name] = kwargs.pop(name)
+ ret = from_config_func(*args, **kwargs)
+ # forward the other arguments to __init__
+ ret.update(extra_kwargs)
+ return ret
+
+
+def _called_with_cfg(*args, **kwargs):
+ """
+ Returns:
+ bool: whether the arguments contain CfgNode and should be considered
+ forwarded to from_config.
+ """
+ from omegaconf import DictConfig
+
+ if len(args) and isinstance(args[0], (_CfgNode, DictConfig)):
+ return True
+ if isinstance(kwargs.pop("cfg", None), (_CfgNode, DictConfig)):
+ return True
+ # `from_config`'s first argument is forced to be "cfg".
+ # So the above check covers all cases.
+ return False
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/config/defaults.py b/sd-webui-controlnet/annotator/oneformer/detectron2/config/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffb79e763f076c9ae982c727309e19b8e0ef170f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/config/defaults.py
@@ -0,0 +1,650 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .config import CfgNode as CN
+
+# NOTE: given the new config system
+# (https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html),
+# we will stop adding new functionalities to default CfgNode.
+
+# -----------------------------------------------------------------------------
+# Convention about Training / Test specific parameters
+# -----------------------------------------------------------------------------
+# Whenever an argument can be either used for training or for testing, the
+# corresponding name will be post-fixed by a _TRAIN for a training parameter,
+# or _TEST for a test-specific parameter.
+# For example, the number of images during training will be
+# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
+# IMAGES_PER_BATCH_TEST
+
+# -----------------------------------------------------------------------------
+# Config definition
+# -----------------------------------------------------------------------------
+
+_C = CN()
+
+# The version number, to upgrade from old configs to new ones if any
+# changes happen. It's recommended to keep a VERSION in your config file.
+_C.VERSION = 2
+
+_C.MODEL = CN()
+_C.MODEL.LOAD_PROPOSALS = False
+_C.MODEL.MASK_ON = False
+_C.MODEL.KEYPOINT_ON = False
+_C.MODEL.DEVICE = "cuda"
+_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"
+
+# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file
+# to be loaded to the model. You can find available models in the model zoo.
+_C.MODEL.WEIGHTS = ""
+
+# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR).
+# To train on images of different number of channels, just set different mean & std.
+# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
+_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
+# When using pre-trained models in Detectron1 or any MSRA models,
+# std has been absorbed into its conv1 weights, so the std needs to be set 1.
+# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
+_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]
+
+
+# -----------------------------------------------------------------------------
+# INPUT
+# -----------------------------------------------------------------------------
+_C.INPUT = CN()
+# By default, {MIN,MAX}_SIZE options are used in transforms.ResizeShortestEdge.
+# Please refer to ResizeShortestEdge for detailed definition.
+# Size of the smallest side of the image during training
+_C.INPUT.MIN_SIZE_TRAIN = (800,)
+# Sample size of smallest side by choice or random selection from range given by
+# INPUT.MIN_SIZE_TRAIN
+_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
+# Maximum size of the side of the image during training
+_C.INPUT.MAX_SIZE_TRAIN = 1333
+# Size of the smallest side of the image during testing. Set to zero to disable resize in testing.
+_C.INPUT.MIN_SIZE_TEST = 800
+# Maximum size of the side of the image during testing
+_C.INPUT.MAX_SIZE_TEST = 1333
+# Mode for flipping images used in data augmentation during training
+# choose one of ["horizontal", "vertical", "none"]
+_C.INPUT.RANDOM_FLIP = "horizontal"
+
+# `True` if cropping is used for data augmentation during training
+_C.INPUT.CROP = CN({"ENABLED": False})
+# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation.
+_C.INPUT.CROP.TYPE = "relative_range"
+# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of
+# pixels if CROP.TYPE is "absolute"
+_C.INPUT.CROP.SIZE = [0.9, 0.9]
+
+
+# Whether the model needs RGB, YUV, HSV etc.
+# Should be one of the modes defined here, as we use PIL to read the image:
+# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
+# with BGR being the one exception. One can set image format to BGR, we will
+# internally use RGB for conversion and flip the channels over
+_C.INPUT.FORMAT = "BGR"
+# The ground truth mask format that the model will use.
+# Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
+_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask"
+
+
+# -----------------------------------------------------------------------------
+# Dataset
+# -----------------------------------------------------------------------------
+_C.DATASETS = CN()
+# List of the dataset names for training. Must be registered in DatasetCatalog
+# Samples from these datasets will be merged and used as one dataset.
+_C.DATASETS.TRAIN = ()
+# List of the pre-computed proposal files for training, which must be consistent
+# with datasets listed in DATASETS.TRAIN.
+_C.DATASETS.PROPOSAL_FILES_TRAIN = ()
+# Number of top scoring precomputed proposals to keep for training
+_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000
+# List of the dataset names for testing. Must be registered in DatasetCatalog
+_C.DATASETS.TEST = ()
+# List of the pre-computed proposal files for test, which must be consistent
+# with datasets listed in DATASETS.TEST.
+_C.DATASETS.PROPOSAL_FILES_TEST = ()
+# Number of top scoring precomputed proposals to keep for test
+_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000
+
+# -----------------------------------------------------------------------------
+# DataLoader
+# -----------------------------------------------------------------------------
+_C.DATALOADER = CN()
+# Number of data loading threads
+_C.DATALOADER.NUM_WORKERS = 4
+# If True, each batch should contain only images for which the aspect ratio
+# is compatible. This groups portrait images together, and landscape images
+# are not batched with portrait images.
+_C.DATALOADER.ASPECT_RATIO_GROUPING = True
+# Options: TrainingSampler, RepeatFactorTrainingSampler
+_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
+# Repeat threshold for RepeatFactorTrainingSampler
+_C.DATALOADER.REPEAT_THRESHOLD = 0.0
+# If True, when working on datasets that have instance annotations, the
+# training dataloader will filter out images without associated annotations
+_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True
+
+# ---------------------------------------------------------------------------- #
+# Backbone options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.BACKBONE = CN()
+
+_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
+# Freeze the first several stages so they are not trained.
+# There are 5 stages in ResNet. The first is a convolution, and the following
+# stages are each group of residual blocks.
+_C.MODEL.BACKBONE.FREEZE_AT = 2
+
+
+# ---------------------------------------------------------------------------- #
+# FPN options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.FPN = CN()
+# Names of the input feature maps to be used by FPN
+# They must have contiguous power of 2 strides
+# e.g., ["res2", "res3", "res4", "res5"]
+_C.MODEL.FPN.IN_FEATURES = []
+_C.MODEL.FPN.OUT_CHANNELS = 256
+
+# Options: "" (no norm), "GN"
+_C.MODEL.FPN.NORM = ""
+
+# Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
+_C.MODEL.FPN.FUSE_TYPE = "sum"
+
+
+# ---------------------------------------------------------------------------- #
+# Proposal generator options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.PROPOSAL_GENERATOR = CN()
+# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals"
+_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN"
+# Proposal height and width both need to be greater than MIN_SIZE
+# (at the scale used during training or inference)
+_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0
+
+
+# ---------------------------------------------------------------------------- #
+# Anchor generator options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ANCHOR_GENERATOR = CN()
+# The generator can be any name in the ANCHOR_GENERATOR registry
+_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator"
+# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
+# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for
+# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1.
+# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES.
+_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]
+# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
+# ratios are generated by an anchor generator.
+# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W)
+# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
+# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
+# for all IN_FEATURES.
+_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]
+# Anchor angles.
+# list[list[float]], the angle in degrees, for each input feature map.
+# ANGLES[i] specifies the list of angles for IN_FEATURES[i].
+_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]]
+# Relative offset between the center of the first anchor and the top-left corner of the image
+# Value has to be in [0, 1). Recommend to use 0.5, which means half stride.
+# The value is not expected to affect model accuracy.
+_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0
+
+# ---------------------------------------------------------------------------- #
+# RPN options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.RPN = CN()
+_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY
+
+# Names of the input feature maps to be used by RPN
+# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN
+_C.MODEL.RPN.IN_FEATURES = ["res4"]
+# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels
+# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
+_C.MODEL.RPN.BOUNDARY_THRESH = -1
+# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD]
+# Minimum overlap required between an anchor and ground-truth box for the
+# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
+# ==> positive RPN example: 1)
+# Maximum overlap allowed between an anchor and ground-truth box for the
+# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD
+# ==> negative RPN example: 0)
+# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD)
+# are ignored (-1)
+_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7]
+_C.MODEL.RPN.IOU_LABELS = [0, -1, 1]
+# Number of regions per image used to train RPN
+_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
+# Target fraction of foreground (positive) examples per RPN minibatch
+_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
+# Options are: "smooth_l1", "giou", "diou", "ciou"
+_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1"
+_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0
+# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
+_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
+# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
+_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0
+_C.MODEL.RPN.LOSS_WEIGHT = 1.0
+# Number of top scoring RPN proposals to keep before applying NMS
+# When FPN is used, this is *per FPN level* (not total)
+_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000
+_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000
+# Number of top scoring RPN proposals to keep after applying NMS
+# When FPN is used, this limit is applied per level and then again to the union
+# of proposals from all levels
+# NOTE: When FPN is used, the meaning of this config is different from Detectron1.
+# It means per-batch topk in Detectron1, but per-image topk here.
+# See the "find_top_rpn_proposals" function for details.
+_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
+_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
+# NMS threshold used on RPN proposals
+_C.MODEL.RPN.NMS_THRESH = 0.7
+# Set this to -1 to use the same number of output channels as input channels.
+_C.MODEL.RPN.CONV_DIMS = [-1]
+
+# ---------------------------------------------------------------------------- #
+# ROI HEADS options
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ROI_HEADS = CN()
+_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads"
+# Number of foreground classes
+_C.MODEL.ROI_HEADS.NUM_CLASSES = 80
+# Names of the input feature maps to be used by ROI heads
+# Currently all heads (box, mask, ...) use the same input feature map list
+# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN
+_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"]
+# IOU overlap ratios [IOU_THRESHOLD]
+# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
+# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
+_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
+_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
+# RoI minibatch size *per image* (number of regions of interest [ROIs]) during training
+# Total number of RoIs per training minibatch =
+# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
+# E.g., a common configuration is: 512 * 16 = 8192
+_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
+# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
+_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
+
+# Only used on test mode
+
+# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
+# balance obtaining high recall with not having too many low precision
+# detections that will slow down inference post processing steps (like NMS)
+# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down
+# inference.
+_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
+# Overlap threshold used for non-maximum suppression (suppress boxes with
+# IoU >= this threshold)
+_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
+# If True, augment proposals with ground-truth boxes before sampling proposals to
+# train ROI heads.
+_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True
+
+# ---------------------------------------------------------------------------- #
+# Box Head
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ROI_BOX_HEAD = CN()
+# C4 models don't use the head name option
+# Options for non-C4 models: FastRCNNConvFCHead,
+_C.MODEL.ROI_BOX_HEAD.NAME = ""
+# Options are: "smooth_l1", "giou", "diou", "ciou"
+_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1"
+# The final scaling coefficient on the box regression loss, used to balance the magnitude of its
+# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`.
+_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0
+# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
+# These are empirically chosen to approximately lead to unit variance targets
+_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
+# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
+_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0
+_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
+_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
+# Type of pooling operation applied to the incoming feature map for each RoI
+_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2"
+
+_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0
+# Hidden layer dimension for FC layers in the RoI box head
+_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024
+_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
+# Channel dimension for Conv layers in the RoI box head
+_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256
+# Normalization method for the convolution layers.
+# Options: "" (no norm), "GN", "SyncBN".
+_C.MODEL.ROI_BOX_HEAD.NORM = ""
+# Whether to use class agnostic for bbox regression
+_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False
+# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes.
+_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False
+
+# Federated loss can be used to improve the training of LVIS
+_C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False
+# Sigmoid cross entropy is used with federated loss
+_C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False
+# The power value applied to image_count when calculating frequency weight
+_C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER = 0.5
+# Number of classes to keep in total
+_C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES = 50
+
+# ---------------------------------------------------------------------------- #
+# Cascaded Box Head
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ROI_BOX_CASCADE_HEAD = CN()
+# The number of cascade stages is implicitly defined by the length of the following two configs.
+_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = (
+ (10.0, 10.0, 5.0, 5.0),
+ (20.0, 20.0, 10.0, 10.0),
+ (30.0, 30.0, 15.0, 15.0),
+)
+_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7)
+
+
+# ---------------------------------------------------------------------------- #
+# Mask Head
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ROI_MASK_HEAD = CN()
+_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
+_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
+_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
+_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head
+_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256
+# Normalization method for the convolution layers.
+# Options: "" (no norm), "GN", "SyncBN".
+_C.MODEL.ROI_MASK_HEAD.NORM = ""
+# Whether to use class agnostic for mask prediction
+_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False
+# Type of pooling operation applied to the incoming feature map for each RoI
+_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2"
+
+
+# ---------------------------------------------------------------------------- #
+# Keypoint Head
+# ---------------------------------------------------------------------------- #
+_C.MODEL.ROI_KEYPOINT_HEAD = CN()
+_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead"
+_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
+_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
+_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8))
+_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO.
+
+# Images with too few (or no) keypoints are excluded from training.
+_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1
+# Normalize by the total number of visible keypoints in the minibatch if True.
+# Otherwise, normalize by the total number of keypoints that could ever exist
+# in the minibatch.
+# The keypoint softmax loss is only calculated on visible keypoints.
+# Since the number of visible keypoints can vary significantly between
+# minibatches, this has the effect of up-weighting the importance of
+# minibatches with few visible keypoints. (Imagine the extreme case of
+# only one visible keypoint versus N: in the case of N, each one
+# contributes 1/N to the gradient compared to the single keypoint
+# determining the gradient direction). Instead, we can normalize the
+# loss by the total number of keypoints, if it were the case that all
+# keypoints were visible in a full minibatch. (Returning to the example,
+# this means that the one visible keypoint contributes as much as each
+# of the N keypoints.)
+_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True
+# Multi-task loss weight to use for keypoints
+# Recommended values:
+# - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True
+# - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False
+_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0
+# Type of pooling operation applied to the incoming feature map for each RoI
+_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2"
+
+# ---------------------------------------------------------------------------- #
+# Semantic Segmentation Head
+# ---------------------------------------------------------------------------- #
+_C.MODEL.SEM_SEG_HEAD = CN()
+_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead"
+_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"]
+# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for
+# the corresponding pixel.
+_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255
+# Number of classes in the semantic segmentation head
+_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54
+# Number of channels in the 3x3 convs inside semantic-FPN heads.
+_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128
+# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride.
+_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
+# Normalization method for the convolution layers. Options: "" (no norm), "GN".
+_C.MODEL.SEM_SEG_HEAD.NORM = "GN"
+_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0
+
+_C.MODEL.PANOPTIC_FPN = CN()
+# Scaling of all losses from instance detection / segmentation head.
+_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0
+
+# options when combining instance & semantic segmentation outputs
+_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True}) # "COMBINE.ENABLED" is deprecated & not used
+_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5
+_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096
+_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5
+
+
+# ---------------------------------------------------------------------------- #
+# RetinaNet Head
+# ---------------------------------------------------------------------------- #
+_C.MODEL.RETINANET = CN()
+
+# This is the number of foreground classes.
+_C.MODEL.RETINANET.NUM_CLASSES = 80
+
+_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]
+
+# Convolutions to use in the cls and bbox tower
+# NOTE: this doesn't include the last conv for logits
+_C.MODEL.RETINANET.NUM_CONVS = 4
+
+# IoU overlap ratio [bg, fg] for labeling anchors.
+# Anchors with < bg are labeled negative (0)
+# Anchors with >= bg and < fg are ignored (-1)
+# Anchors with >= fg are labeled positive (1)
+_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5]
+_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1]
+
+# Prior prob for rare case (i.e. foreground) at the beginning of training.
+# This is used to set the bias for the logits layer of the classifier subnet.
+# This improves training stability in the case of heavy class imbalance.
+_C.MODEL.RETINANET.PRIOR_PROB = 0.01
+
+# Inference cls score threshold, only anchors with score > SCORE_THRESH_TEST are
+# considered for inference (to improve speed)
+_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
+# Select topk candidates before NMS
+_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
+_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
+
+# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets
+_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
+
+# Loss parameters
+_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
+_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
+_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
+# Options are: "smooth_l1", "giou", "diou", "ciou"
+_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1"
+
+# One of BN, SyncBN, FrozenBN, GN
+# Only supports GN until unshared norm is implemented
+_C.MODEL.RETINANET.NORM = ""
+
+
+# ---------------------------------------------------------------------------- #
+# ResNe[X]t options (ResNets = {ResNet, ResNeXt})
+# Note that parts of a resnet may be used for both the backbone and the head
+# These options apply to both
+# ---------------------------------------------------------------------------- #
+_C.MODEL.RESNETS = CN()
+
+_C.MODEL.RESNETS.DEPTH = 50
+_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone
+
+# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
+_C.MODEL.RESNETS.NUM_GROUPS = 1
+
+# Options: "FrozenBN", "GN", "SyncBN", "BN"
+_C.MODEL.RESNETS.NORM = "FrozenBN"
+
+# Baseline width of each group.
+# Scaling this parameter will scale the width of all bottleneck layers.
+_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
+
+# Place the stride 2 conv on the 1x1 filter
+# Use True only for the original MSRA ResNet; use False for C2 and Torch models
+_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
+
+# Apply dilation in stage "res5"
+_C.MODEL.RESNETS.RES5_DILATION = 1
+
+# Output width of res2. Scaling this parameter will scale the width of all 1x1 convs in ResNet
+# For R18 and R34, this needs to be set to 64
+_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
+_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
+
+# Apply Deformable Convolution in stages
+# Specify if apply deform_conv on Res2, Res3, Res4, Res5
+_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False]
+# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168);
+# Use False for DeformableV1.
+_C.MODEL.RESNETS.DEFORM_MODULATED = False
+# Number of groups in deformable conv.
+_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1
+
+
+# ---------------------------------------------------------------------------- #
+# Solver
+# ---------------------------------------------------------------------------- #
+_C.SOLVER = CN()
+
+# Options: WarmupMultiStepLR, WarmupCosineLR.
+# See detectron2/solver/build.py for definition.
+_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
+
+_C.SOLVER.MAX_ITER = 40000
+
+_C.SOLVER.BASE_LR = 0.001
+# The end lr, only used by WarmupCosineLR
+_C.SOLVER.BASE_LR_END = 0.0
+
+_C.SOLVER.MOMENTUM = 0.9
+
+_C.SOLVER.NESTEROV = False
+
+_C.SOLVER.WEIGHT_DECAY = 0.0001
+# The weight decay that's applied to parameters of normalization layers
+# (typically the affine transformation)
+_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
+
+_C.SOLVER.GAMMA = 0.1
+# The iteration number to decrease learning rate by GAMMA.
+_C.SOLVER.STEPS = (30000,)
+# Number of decays in WarmupStepWithFixedGammaLR schedule
+_C.SOLVER.NUM_DECAYS = 3
+
+_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
+_C.SOLVER.WARMUP_ITERS = 1000
+_C.SOLVER.WARMUP_METHOD = "linear"
+# Whether to rescale the interval for the learning schedule after warmup
+_C.SOLVER.RESCALE_INTERVAL = False
+
+# Save a checkpoint after every this number of iterations
+_C.SOLVER.CHECKPOINT_PERIOD = 5000
+
+# Number of images per batch across all machines. This is also the number
+# of training images per step (i.e. per iteration). If we use 16 GPUs
+# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch.
+# May be adjusted automatically if REFERENCE_WORLD_SIZE is set.
+_C.SOLVER.IMS_PER_BATCH = 16
+
+# The reference number of workers (GPUs) this config is meant to train with.
+# It takes no effect when set to 0.
+# With a non-zero value, it will be used by DefaultTrainer to compute a desired
+# per-worker batch size, and then scale the other related configs (total batch size,
+# learning rate, etc) to match the per-worker batch size.
+# See documentation of `DefaultTrainer.auto_scale_workers` for details.
+_C.SOLVER.REFERENCE_WORLD_SIZE = 0
+
+# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for
+# biases. This is not useful (at least for recent models). You should avoid
+# changing these and they exist only to reproduce Detectron v1 training if
+# desired.
+_C.SOLVER.BIAS_LR_FACTOR = 1.0
+_C.SOLVER.WEIGHT_DECAY_BIAS = None # None means following WEIGHT_DECAY
+
+# Gradient clipping
+_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
+# Type of gradient clipping, currently 2 values are supported:
+# - "value": the absolute values of elements of each gradient are clipped
+# - "norm": the norm of the gradient for each parameter is clipped thus
+# affecting all elements in the parameter
+_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
+# Maximum absolute value used for clipping gradients
+_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
+# Floating point number p for L-p norm to be used with the "norm"
+# gradient clipping type; for L-inf, please specify .inf
+_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
+
+# Enable automatic mixed precision for training
+# Note that this does not change model's inference behavior.
+# To use AMP in inference, run inference under autocast()
+_C.SOLVER.AMP = CN({"ENABLED": False})
+
+# ---------------------------------------------------------------------------- #
+# Specific test options
+# ---------------------------------------------------------------------------- #
+_C.TEST = CN()
+# For end-to-end tests to verify the expected accuracy.
+# Each item is [task, metric, value, tolerance]
+# e.g.: [['bbox', 'AP', 38.5, 0.2]]
+_C.TEST.EXPECTED_RESULTS = []
+# The period (in terms of steps) to evaluate the model during training.
+# Set to 0 to disable.
+_C.TEST.EVAL_PERIOD = 0
+# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval
+# When empty, it will use the defaults in COCO.
+# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
+_C.TEST.KEYPOINT_OKS_SIGMAS = []
+# Maximum number of detections to return per image during inference (100 is
+# based on the limit established for the COCO dataset).
+_C.TEST.DETECTIONS_PER_IMAGE = 100
+
+_C.TEST.AUG = CN({"ENABLED": False})
+_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
+_C.TEST.AUG.MAX_SIZE = 4000
+_C.TEST.AUG.FLIP = True
+
+_C.TEST.PRECISE_BN = CN({"ENABLED": False})
+_C.TEST.PRECISE_BN.NUM_ITER = 200
+
+# ---------------------------------------------------------------------------- #
+# Misc options
+# ---------------------------------------------------------------------------- #
+# Directory where output files are written
+_C.OUTPUT_DIR = "./output"
+# Set seed to negative to fully randomize everything.
+# Set seed to positive to use a fixed seed. Note that a fixed seed increases
+# reproducibility but does not guarantee fully deterministic behavior.
+# Disabling all parallelism further increases reproducibility.
+_C.SEED = -1
+# Benchmark different cudnn algorithms.
+# If input images have very different sizes, this option will have large overhead
+# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
+# If input images have the same or similar sizes, benchmark is often helpful.
+_C.CUDNN_BENCHMARK = False
+# The period (in terms of steps) for minibatch visualization at train time.
+# Set to 0 to disable.
+_C.VIS_PERIOD = 0
+
+# The global config is for quick hack purposes.
+# You can set its options in command line or config files,
+# and access them with:
+#
+# from annotator.oneformer.detectron2.config import global_cfg
+# print(global_cfg.HACK)
+#
+# Do not commit any configs into it.
+_C.GLOBAL = CN()
+_C.GLOBAL.HACK = 1.0
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/config/instantiate.py b/sd-webui-controlnet/annotator/oneformer/detectron2/config/instantiate.py
new file mode 100644
index 0000000000000000000000000000000000000000..26d191b03f800dae5620128957d137cd4fdb1728
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/config/instantiate.py
@@ -0,0 +1,88 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import collections.abc as abc
+import dataclasses
+import logging
+from typing import Any
+
+from annotator.oneformer.detectron2.utils.registry import _convert_target_to_string, locate
+
+__all__ = ["dump_dataclass", "instantiate"]
+
+
+def dump_dataclass(obj: Any):
+    """
+    Recursively dump a dataclass instance into a dict that can be instantiated later.
+
+    Args:
+        obj: a dataclass instance (not a dataclass type)
+
+    Returns:
+        dict: the dumped fields plus a "_target_" entry naming the dataclass type
+    """
+    is_instance = dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+    assert is_instance, "dump_dataclass() requires an instance of a dataclass."
+    def _dump(item):
+        # nested dataclasses are dumped recursively; anything else is kept as-is
+        return dump_dataclass(item) if dataclasses.is_dataclass(item) else item
+    dumped = {"_target_": _convert_target_to_string(type(obj))}
+    for field in dataclasses.fields(obj):
+        value = _dump(getattr(obj, field.name))
+        if isinstance(value, (list, tuple)):
+            value = [_dump(elem) for elem in value]
+        dumped[field.name] = value
+    return dumped
+
+
+def instantiate(cfg):
+    """
+    Build the object(s) described by `cfg`, recursing into nested containers.
+
+    Args:
+        cfg: typically a dict-like object whose "_target_" entry defines the callable
+            to invoke and whose other keys define its keyword arguments.
+            Lists are instantiated element by element.
+
+    Returns:
+        the object instantiated from cfg; input that matches none of the handled
+        cases is returned as-is
+    """
+    from omegaconf import ListConfig, DictConfig, OmegaConf
+
+    if isinstance(cfg, ListConfig):
+        items = [instantiate(elem) for elem in cfg]
+        return ListConfig(items, flags={"allow_objects": True})
+    if isinstance(cfg, list):
+        # Plain lists get their own branch, because many classes take
+        # list[objects] as arguments, such as ResNet, DatasetMapper
+        return [instantiate(elem) for elem in cfg]
+
+    # A DictConfig backed by a dataclass (i.e. omegaconf's structured config)
+    # is converted to the actual dataclass.
+    if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type):
+        return OmegaConf.to_object(cfg)
+
+    if not (isinstance(cfg, abc.Mapping) and "_target_" in cfg):
+        return cfg  # return as-is if don't know what to do
+
+    # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all,
+    # but faster: https://github.com/facebookresearch/hydra/issues/1200
+    kwargs = {key: instantiate(val) for key, val in cfg.items()}
+    target = instantiate(kwargs.pop("_target_"))
+
+    if isinstance(target, str):
+        target_name = target
+        target = locate(target_name)
+        assert target is not None, target_name
+    else:
+        try:
+            target_name = target.__module__ + "." + target.__qualname__
+        except Exception:
+            # target could be anything, so building a dotted name could fail
+            target_name = str(target)
+    assert callable(target), f"_target_ {target} does not define a callable object"
+    try:
+        return target(**kwargs)
+    except TypeError:
+        logging.getLogger(__name__).error(f"Error when instantiating {target_name}!")
+        raise
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/config/lazy.py b/sd-webui-controlnet/annotator/oneformer/detectron2/config/lazy.py
new file mode 100644
index 0000000000000000000000000000000000000000..72a3e5c036f9f78a2cdf3ef0975639da3299d694
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/config/lazy.py
@@ -0,0 +1,435 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import ast
+import builtins
+import collections.abc as abc
+import importlib
+import inspect
+import logging
+import os
+import uuid
+from contextlib import contextmanager
+from copy import deepcopy
+from dataclasses import is_dataclass
+from typing import List, Tuple, Union
+import yaml
+from omegaconf import DictConfig, ListConfig, OmegaConf, SCMode
+
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+from annotator.oneformer.detectron2.utils.registry import _convert_target_to_string
+
+__all__ = ["LazyCall", "LazyConfig"]
+
+
+class LazyCall:
+    """
+    Deferred call: wraps a callable so that "calling" it does not execute it,
+    but returns a config (with a "_target_" entry) that describes the call.
+
+    A LazyCall object accepts keyword arguments only. Positional
+    arguments are not yet supported.
+
+    Examples:
+    ::
+        from annotator.oneformer.detectron2.config import instantiate, LazyCall
+
+        layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
+        layer_cfg.out_channels = 64   # can edit it afterwards
+        layer = instantiate(layer_cfg)
+    """
+
+    def __init__(self, target):
+        # accepted targets: a callable, a string, or a mapping
+        is_valid = callable(target) or isinstance(target, (str, abc.Mapping))
+        if not is_valid:
+            message = f"target of LazyCall must be a callable or defines a callable! Got {target}"
+            raise TypeError(message)
+        self._target = target
+
+    def __call__(self, **kwargs):
+        target = self._target
+        if is_dataclass(target):
+            # omegaconf object cannot hold dataclass type
+            # https://github.com/omry/omegaconf/issues/784
+            target = _convert_target_to_string(target)
+        call_spec = dict(kwargs, _target_=target)
+
+        return DictConfig(content=call_spec, flags={"allow_objects": True})
+
+
+def _visit_dict_config(cfg, func):
+    """Call func on every DictConfig nested in cfg, visiting parents before children."""
+    if isinstance(cfg, DictConfig):
+        func(cfg)
+        children = cfg.values()
+    elif isinstance(cfg, ListConfig):
+        children = cfg
+    else:
+        return
+    for child in children:
+        _visit_dict_config(child, func)
+
+
+def _validate_py_syntax(filename):
+    # parse-only check, cf. https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
+    with PathManager.open(filename, "r") as stream:
+        source_text = stream.read()
+    try:
+        ast.parse(source_text)
+    except SyntaxError as err:
+        raise SyntaxError(f"Config file {filename} has syntax error!") from err
+
+
+def _cast_to_config(obj):
+    # plain dicts are wrapped into a DictConfig; anything else is passed through
+    if not isinstance(obj, dict):
+        return obj
+    return DictConfig(obj, flags={"allow_objects": True})
+
+
+_CFG_PACKAGE_NAME = "detectron2._cfg_loader"
+"""
+A namespace to put all imported config into.
+"""
+
+
+def _random_package_name(filename):
+    tag = str(uuid.uuid4())[:4]  # short random part, so each loaded config gets its own name
+    return "".join((_CFG_PACKAGE_NAME, tag, ".", os.path.basename(filename)))
+
+
+@contextmanager
+def _patch_import():
+    """
+    Temporarily replace `builtins.__import__` so that relative imports in config files:
+    1. locate files purely based on relative location, regardless of packages.
+       e.g. you can import file without having __init__
+    2. do not cache modules globally; modifications of module states have no side effect
+    3. support other storage system through PathManager, so config files can be in the cloud
+    4. imported dict are turned into omegaconf.DictConfig automatically
+    Yields the patched import function; the original import is restored on exit.
+    """
+    old_import = builtins.__import__
+
+    def find_relative_file(original_file, relative_import_path, level):
+        # NOTE: "from . import x" is not handled. Because then it's unclear
+        # if such import should produce `x` as a python module or DictConfig.
+        # This can be discussed further if needed.
+        relative_import_err = """
+Relative import of directories is not allowed within config files.
+Within a config file, relative import can only import other config files.
+""".replace("\n", " ")
+        if not len(relative_import_path):
+            raise ImportError(relative_import_err)
+
+        cur_file = os.path.dirname(original_file)
+        for _ in range(level - 1):
+            cur_file = os.path.dirname(cur_file)
+        cur_name = relative_import_path.lstrip(".")
+        for part in cur_name.split("."):
+            cur_file = os.path.join(cur_file, part)
+        if not cur_file.endswith(".py"):
+            cur_file += ".py"
+        if not PathManager.isfile(cur_file):
+            cur_file_no_suffix = cur_file[: -len(".py")]
+            if PathManager.isdir(cur_file_no_suffix):
+                raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err)
+            else:
+                raise ImportError(
+                    f"Cannot import name {relative_import_path} from "
+                    f"{original_file}: {cur_file} does not exist."
+                )
+        return cur_file
+
+    def new_import(name, globals=None, locals=None, fromlist=(), level=0):
+        if (
+            # Only deal with relative imports inside config files
+            level != 0
+            and globals is not None
+            and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
+        ):
+            cur_file = find_relative_file(globals["__file__"], name, level)
+            _validate_py_syntax(cur_file)
+            spec = importlib.machinery.ModuleSpec(
+                _random_package_name(cur_file), None, origin=cur_file
+            )
+            module = importlib.util.module_from_spec(spec)
+            module.__file__ = cur_file
+            with PathManager.open(cur_file) as f:
+                content = f.read()
+            exec(compile(content, cur_file, "exec"), module.__dict__)
+            for attr in fromlist:  # turn imported dict into DictConfig automatically
+                module.__dict__[attr] = _cast_to_config(module.__dict__[attr])
+            return module
+        return old_import(name, globals, locals, fromlist=fromlist, level=level)
+
+    builtins.__import__ = new_import
+    try:
+        yield new_import
+    finally:  # restore the real import even when the `with` body raises
+        builtins.__import__ = old_import
+
+
+class LazyConfig:
+    """
+    Provide methods to save, load, and override an omegaconf config object
+    which may contain definition of lazily-constructed objects.
+    """
+
+    @staticmethod
+    def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
+        """
+        Similar to :meth:`load()`, but load path relative to the caller's
+        source file.
+
+        This has the same functionality as a relative import, except that this method
+        accepts filename as a string, so more characters are allowed in the filename.
+        """
+        caller_frame = inspect.stack()[1]
+        caller_fname = caller_frame[0].f_code.co_filename
+        assert caller_fname != "", "load_rel Unable to find caller"
+        caller_dir = os.path.dirname(caller_fname)
+        filename = os.path.join(caller_dir, filename)
+        return LazyConfig.load(filename, keys)
+
+    @staticmethod
+    def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
+        """
+        Load a config file (.py, .yaml or .yml; other extensions raise ValueError).
+
+        Args:
+            filename: absolute path or relative path w.r.t. the current working directory
+            keys: keys to load and return. If not given, return all keys
+                (whose values are config objects) in a dict.
+        Returns:
+            the config(s) selected by `keys`, or all config objects if `keys` is None
+        """
+        filename = filename.replace("/./", "/")  # redundant
+        if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
+            raise ValueError(f"Config file {filename} has to be a python or yaml file.")
+        if filename.endswith(".py"):
+            _validate_py_syntax(filename)
+            with _patch_import():
+                # Record the filename
+                module_namespace = {
+                    "__file__": filename,
+                    "__package__": _random_package_name(filename),
+                }
+                with PathManager.open(filename) as f:
+                    content = f.read()
+                # Compile first with filename to:
+                # 1. make filename appears in stacktrace
+                # 2. make load_rel able to find its parent's (possibly remote) location
+                exec(compile(content, filename, "exec"), module_namespace)
+
+            ret = module_namespace
+        else:
+            # NOTE(security): unsafe_load can construct arbitrary Python objects;
+            # only load yaml files that come from a trusted source.
+            with PathManager.open(filename) as f:
+                obj = yaml.unsafe_load(f)
+            ret = OmegaConf.create(obj, flags={"allow_objects": True})
+
+        if keys is not None:
+            if isinstance(keys, str):
+                return _cast_to_config(ret[keys])
+            return tuple(_cast_to_config(ret[a]) for a in keys)
+        if filename.endswith(".py"):
+            # when not specified, only load those that are config objects
+            ret = DictConfig(
+                {
+                    name: _cast_to_config(value)
+                    for name, value in ret.items()
+                    if isinstance(value, (DictConfig, ListConfig, dict))
+                    and not name.startswith("_")
+                },
+                flags={"allow_objects": True},
+            )
+        return ret
+
+    @staticmethod
+    def save(cfg, filename: str):
+        """
+        Save a config object to a yaml file.
+        Note that when the config dictionary contains complex objects (e.g. lambda),
+        it can't be saved to yaml. In that case an error or warning is logged; the
+        cloudpickle fallback is currently disabled (see the commented-out code below).
+
+        Args:
+            cfg: an omegaconf config object
+            filename: yaml file name to save the config file
+        """
+        logger = logging.getLogger(__name__)
+        try:
+            cfg = deepcopy(cfg)
+        except Exception:
+            pass
+        else:
+            # if it's deep-copyable, then...
+            def _replace_type_by_name(x):
+                if "_target_" in x and callable(x._target_):
+                    try:
+                        x._target_ = _convert_target_to_string(x._target_)
+                    except AttributeError:
+                        pass
+
+            # not necessary, but makes yaml looks nicer
+            _visit_dict_config(cfg, _replace_type_by_name)
+
+        save_pkl = False
+        try:
+            cfg_dict = OmegaConf.to_container(
+                cfg,
+                # Do not resolve interpolation when saving, i.e. do not turn ${a} into
+                # actual values when saving.
+                resolve=False,
+                # Save structures (dataclasses) in a format that can be instantiated later.
+                # Without this option, the type information of the dataclass will be erased.
+                structured_config_mode=SCMode.INSTANTIATE,
+            )
+            dumped = yaml.dump(cfg_dict, default_flow_style=None, allow_unicode=True, width=9999)
+            with PathManager.open(filename, "w") as f:
+                f.write(dumped)
+            try:
+                _ = yaml.unsafe_load(dumped)  # test that it is loadable
+            except Exception:
+                logger.warning(
+                    "The config contains objects that cannot serialize to a valid yaml. "
+                    f"{filename} is human-readable but cannot be loaded."
+                )
+                save_pkl = True
+        except Exception:
+            logger.exception("Unable to serialize the config to yaml. Error:")
+            save_pkl = True
+
+        if save_pkl:
+            new_filename = filename + ".pkl"
+            logger.warning(f"Pickle fallback is disabled; {new_filename} is not written.")
+            # try:
+            #     # retry by pickle
+            #     with PathManager.open(new_filename, "wb") as f:
+            #         cloudpickle.dump(cfg, f)
+            #     logger.warning(f"Config is saved using cloudpickle at {new_filename}.")
+            # except Exception:
+            #     pass
+
+    @staticmethod
+    def apply_overrides(cfg, overrides: List[str]):
+        """
+        In-place override contents of cfg.
+
+        Args:
+            cfg: an omegaconf config object
+            overrides: list of strings in the format of "a=b" to override configs.
+                See https://hydra.cc/docs/next/advanced/override_grammar/basic/
+                for syntax.
+
+        Returns:
+            the cfg object
+        """
+
+        def safe_update(cfg, key, value):
+            parts = key.split(".")
+            for idx in range(1, len(parts)):
+                prefix = ".".join(parts[:idx])
+                v = OmegaConf.select(cfg, prefix, default=None)
+                if v is None:
+                    break
+                if not OmegaConf.is_config(v):
+                    raise KeyError(
+                        f"Trying to update key {key}, but {prefix} "
+                        f"is not a config, but has type {type(v)}."
+                    )
+            OmegaConf.update(cfg, key, value, merge=True)
+
+        try:
+            from hydra.core.override_parser.overrides_parser import OverridesParser
+            has_hydra = True
+        except ImportError:
+            has_hydra = False
+
+        if has_hydra:
+            parser = OverridesParser.create()
+            overrides = parser.parse_overrides(overrides)
+            for o in overrides:
+                key = o.key_or_group
+                value = o.value()
+                if o.is_delete():
+                    # TODO support this
+                    raise NotImplementedError("deletion is not yet a supported override")
+                safe_update(cfg, key, value)
+        else:
+            # Fallback. Does not support all the features and error checking like hydra.
+            # NOTE(security): values are passed to eval(); only use trusted overrides.
+            for o in overrides:
+                key, value = o.split("=", 1)  # first "=" only: values may contain "="
+                try:
+                    value = eval(value, {})
+                except (NameError, SyntaxError):
+                    pass  # not a Python expression (e.g. a path): keep the raw string
+                safe_update(cfg, key, value)
+        return cfg
+
+ # @staticmethod
+ # def to_py(cfg, prefix: str = "cfg."):
+ # """
+ # Try to convert a config object into Python-like psuedo code.
+ #
+ # Note that perfect conversion is not always possible. So the returned
+ # results are mainly meant to be human-readable, and not meant to be executed.
+ #
+ # Args:
+ # cfg: an omegaconf config object
+ # prefix: root name for the resulting code (default: "cfg.")
+ #
+ #
+ # Returns:
+ # str of formatted Python code
+ # """
+ # import black
+ #
+ # cfg = OmegaConf.to_container(cfg, resolve=True)
+ #
+ # def _to_str(obj, prefix=None, inside_call=False):
+ # if prefix is None:
+ # prefix = []
+ # if isinstance(obj, abc.Mapping) and "_target_" in obj:
+ # # Dict representing a function call
+ # target = _convert_target_to_string(obj.pop("_target_"))
+ # args = []
+ # for k, v in sorted(obj.items()):
+ # args.append(f"{k}={_to_str(v, inside_call=True)}")
+ # args = ", ".join(args)
+ # call = f"{target}({args})"
+ # return "".join(prefix) + call
+ # elif isinstance(obj, abc.Mapping) and not inside_call:
+ # # Dict that is not inside a call is a list of top-level config objects that we
+ # # render as one object per line with dot separated prefixes
+ # key_list = []
+ # for k, v in sorted(obj.items()):
+ # if isinstance(v, abc.Mapping) and "_target_" not in v:
+ # key_list.append(_to_str(v, prefix=prefix + [k + "."]))
+ # else:
+ # key = "".join(prefix) + k
+ # key_list.append(f"{key}={_to_str(v)}")
+ # return "\n".join(key_list)
+ # elif isinstance(obj, abc.Mapping):
+ # # Dict that is inside a call is rendered as a regular dict
+ # return (
+ # "{"
+ # + ",".join(
+ # f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
+ # for k, v in sorted(obj.items())
+ # )
+ # + "}"
+ # )
+ # elif isinstance(obj, list):
+ # return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]"
+ # else:
+ # return repr(obj)
+ #
+ # py_str = _to_str(cfg, prefix=[prefix])
+ # try:
+ # return black.format_str(py_str, mode=black.Mode())
+ # except black.InvalidInput:
+ # return py_str
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..259f669b78bd05815cb8d3351fd6c5fc9a1b85a1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import transforms # isort:skip
+
+from .build import (
+ build_batch_data_loader,
+ build_detection_test_loader,
+ build_detection_train_loader,
+ get_detection_dataset_dicts,
+ load_proposals_into_dataset,
+ print_instances_class_histogram,
+)
+from .catalog import DatasetCatalog, MetadataCatalog, Metadata
+from .common import DatasetFromList, MapDataset, ToIterableDataset
+from .dataset_mapper import DatasetMapper
+
+# ensure the builtin datasets are registered
+from . import datasets, samplers # isort:skip
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/benchmark.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfd650582c83cd032b4fe76303517cdfd9a2a8b4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/benchmark.py
@@ -0,0 +1,225 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+from itertools import count
+from typing import List, Tuple
+import torch
+import tqdm
+from fvcore.common.timer import Timer
+
+from annotator.oneformer.detectron2.utils import comm
+
+from .build import build_batch_data_loader
+from .common import DatasetFromList, MapDataset
+from .samplers import TrainingSampler
+
+logger = logging.getLogger(__name__)
+
+
+class _EmptyMapDataset(torch.utils.data.Dataset):
+    """
+    Wrap a dataset so each lookup still runs, but a constant ``[0]`` is returned.
+    """
+
+    def __init__(self, dataset):
+        self.ds = dataset
+
+    def __getitem__(self, idx):
+        self.ds[idx]  # evaluate the wrapped item, then throw the result away
+        return [0]
+
+    def __len__(self):
+        return len(self.ds)
+
+
+def iter_benchmark(
+    iterator, num_iter: int, warmup: int = 5, max_time_seconds: float = 60
+) -> Tuple[float, List[float]]:
+    """
+    Benchmark an iterator/iterable for `num_iter` iterations with an extra
+    `warmup` iterations of warmup.
+    End early if `max_time_seconds` time is spent on iterations.
+
+    Returns:
+        float: average time (seconds) per iteration; NaN if no iteration was timed
+        list[float]: time spent on each iteration. Sometimes useful for further analysis.
+    """
+    num_iter, warmup = int(num_iter), int(warmup)
+    iterator = iter(iterator)
+    for _ in range(warmup):
+        next(iterator)
+    timer = Timer()
+    all_times = []
+    for _ in tqdm.trange(num_iter):
+        start = timer.seconds()
+        if start > max_time_seconds:
+            break
+        next(iterator)
+        all_times.append(timer.seconds() - start)
+    # Average over the iterations actually timed. With none (num_iter == 0 or the
+    # time budget already spent) the old `/ num_iter` raised ZeroDivisionError.
+    avg = timer.seconds() / len(all_times) if all_times else float("nan")
+    return avg, all_times
+
+
+class DataLoaderBenchmark:
+    """
+    A set of benchmarks for locating the slow stage of a standard dataloader
+    built from a dataset, a mapper and a sampler.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        *,
+        mapper,
+        sampler=None,
+        total_batch_size,
+        num_workers=0,
+        max_time_seconds: int = 90,
+    ):
+        """
+        Args:
+            max_time_seconds (int): time budget, in seconds, after which each
+                benchmark stops early
+            other args: same as in `build.py:build_detection_train_loader`.
+                A plain `list` dataset is wrapped in a serialized `DatasetFromList`.
+        """
+        if isinstance(dataset, list):
+            dataset = DatasetFromList(dataset, copy=False, serialize=True)
+
+        self.dataset = dataset
+        self.mapper = mapper
+        self.sampler = TrainingSampler(len(dataset)) if sampler is None else sampler
+        self.num_workers = num_workers
+        self.max_time_seconds = max_time_seconds
+        self.total_batch_size = total_batch_size
+        # whole-number share of the total batch for each process (world size)
+        self.per_gpu_batch_size = total_batch_size // comm.get_world_size()
+
+    def _benchmark(self, iterator, num_iter, warmup, msg=None):
+        result = iter_benchmark(iterator, num_iter, warmup, self.max_time_seconds)
+        if msg is not None:
+            self._log_time(msg, *result)
+        return result
+
+    def _log_time(self, msg, avg, all_times, distributed=False):
+        """Log iterations/second and p1/p5/p95/p99 iteration times, for one or for all ranks."""
+        percentiles = [np.percentile(all_times, k, interpolation="nearest") for k in [1, 5, 95, 99]]
+
+        def report(prefix, mean, pcts):
+            logger.info(
+                f"{prefix}{msg}: avg={1.0/mean:.1f} it/s, "
+                f"p1={pcts[0]:.2g}s, p5={pcts[1]:.2g}s, "
+                f"p95={pcts[2]:.2g}s, p99={pcts[3]:.2g}s."
+            )
+
+        if not distributed:
+            return report("", avg, percentiles)
+        avg_per_gpu = comm.all_gather(avg)
+        percentiles_per_gpu = comm.all_gather(percentiles)
+        if comm.get_rank() > 0:
+            return
+        for rank, (mean, pcts) in enumerate(zip(avg_per_gpu, percentiles_per_gpu)):
+            report(f"GPU{rank} ", mean, pcts)
+
+    def benchmark_dataset(self, num_iter, warmup=5):
+        """
+        Time how fast raw samples can be fetched from the dataset alone.
+        """
+
+        def samples():
+            # restart the sampler each time it is exhausted, so the stream is endless
+            while True:
+                yield from (self.dataset[i] for i in self.sampler)
+
+        self._benchmark(samples(), num_iter, warmup, "Dataset Alone")
+
+    def benchmark_mapper(self, num_iter, warmup=5):
+        """
+        Time fetching raw samples from the dataset and passing each one through
+        the mapper, all in the calling process.
+        """
+
+        def mapped_samples():
+            while True:
+                yield from (self.mapper(self.dataset[i]) for i in self.sampler)
+
+        label = "Single Process Mapper (sec/sample)"
+        self._benchmark(mapped_samples(), num_iter, warmup, label)
+
+    def benchmark_workers(self, num_iter, warmup=10):
+        """
+        Benchmark the dataloader by tuning num_workers to [0, 1, self.num_workers].
+        For n workers, `num_iter` and `warmup` are both scaled by max(n, 1).
+        Each configuration is logged under its own label.
+        """
+        mapped = MapDataset(self.dataset, self.mapper)
+        # 0 and 1 workers are always measured; the configured count only if it differs
+        extra = [] if self.num_workers in (0, 1) else [self.num_workers]
+
+        for n in [0, 1] + extra:
+            scale = max(n, 1)
+            loader = build_batch_data_loader(
+                mapped, self.sampler, self.total_batch_size, num_workers=n
+            )
+            self._benchmark(
+                iter(loader),
+                num_iter * scale,
+                warmup * scale,
+                f"DataLoader ({n} workers, bs={self.per_gpu_batch_size})",
+            )
+            # release this loader before the next configuration is built
+            del loader
+
+    def benchmark_IPC(self, num_iter, warmup=10):
+        """
+        Benchmark the dataloader with every worker's output replaced by a constant,
+        which removes the IPC cost of shipping real samples back from the workers.
+
+        PyTorch multiprocessing's IPC only optimizes for torch tensors.
+        Large numpy arrays or other data structure may incur large IPC overhead.
+        The log label of this run ends with "w/o comm".
+        """
+        n = self.num_workers
+        scale = max(n, 1)
+        # every sample is still loaded and mapped, but `[0]` is returned in its place
+        silent = _EmptyMapDataset(MapDataset(self.dataset, self.mapper))
+        loader = build_batch_data_loader(
+            silent, self.sampler, self.total_batch_size, num_workers=n
+        )
+
+        label = f"DataLoader ({n} workers, bs={self.per_gpu_batch_size}) w/o comm"
+        self._benchmark(iter(loader), num_iter * scale, warmup * scale, label)
+
+    def benchmark_distributed(self, num_iter, warmup=10):
+        """
+        Benchmark the dataloader in every distributed worker and log the results
+        of all of them. This shows the final performance as well as the variance
+        among workers.
+
+        The startup time of the dataloader (time to the first batch) is logged too,
+        before the timed run begins.
+        """
+        gpu = comm.get_world_size()
+        n = self.num_workers
+        scale = max(n, 1)
+        loader = build_batch_data_loader(
+            MapDataset(self.dataset, self.mapper),
+            self.sampler,
+            self.total_batch_size,
+            num_workers=n,
+        )
+
+        # startup time = creating the iterator plus fetching the first batch
+        timer = Timer()
+        loader = iter(loader)
+        next(loader)
+        logger.info("Dataloader startup time: {:.2f} seconds".format(timer.seconds()))
+
+        comm.synchronize()
+
+        avg, all_times = self._benchmark(loader, num_iter * scale, warmup * scale)
+        del loader
+        title = f"DataLoader ({gpu} GPUs x {n} workers, total bs={self.total_batch_size})"
+        self._log_time(title, avg, all_times, True)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/build.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..d03137a9aabfc4a056dd671d4c3d0ba6f349fe03
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/build.py
@@ -0,0 +1,556 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import itertools
+import logging
+import numpy as np
+import operator
+import pickle
+from typing import Any, Callable, Dict, List, Optional, Union
+import torch
+import torch.utils.data as torchdata
+from tabulate import tabulate
+from termcolor import colored
+
+from annotator.oneformer.detectron2.config import configurable
+from annotator.oneformer.detectron2.structures import BoxMode
+from annotator.oneformer.detectron2.utils.comm import get_world_size
+from annotator.oneformer.detectron2.utils.env import seed_all_rng
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+from annotator.oneformer.detectron2.utils.logger import _log_api_usage, log_first_n
+
+from .catalog import DatasetCatalog, MetadataCatalog
+from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset
+from .dataset_mapper import DatasetMapper
+from .detection_utils import check_metadata_consistency
+from .samplers import (
+ InferenceSampler,
+ RandomSubsetTrainingSampler,
+ RepeatFactorTrainingSampler,
+ TrainingSampler,
+)
+
+"""
+This file contains the default logic to build a dataloader for training or testing.
+"""
+
+__all__ = [
+ "build_batch_data_loader",
+ "build_detection_train_loader",
+ "build_detection_test_loader",
+ "get_detection_dataset_dicts",
+ "load_proposals_into_dataset",
+ "print_instances_class_histogram",
+]
+
+
+def filter_images_with_only_crowd_annotations(dataset_dicts):
+    """
+    Drop every image that has no annotation at all, or whose annotations are
+    all flagged ``iscrowd`` (a missing ``iscrowd`` key counts as non-crowd).
+    A common training-time preprocessing on COCO dataset.
+
+    Args:
+        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+
+    Returns:
+        list[dict]: a new list in the same format, holding only the kept images.
+    """
+    # Build a filtered copy; the list passed in is not modified.
+    kept = [
+        record
+        for record in dataset_dicts
+        # keep the image as soon as one annotation is not a crowd region
+        if any(ann.get("iscrowd", 0) == 0 for ann in record["annotations"])
+    ]
+
+    # Report how many images were dropped and how many remain.
+    num_removed = len(dataset_dicts) - len(kept)
+    logging.getLogger(__name__).info(
+        "Removed {} images with no usable annotations. {} images left.".format(
+            num_removed, len(kept)
+        )
+    )
+
+    return kept
+
+
+def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
+    """
+    Drop images whose annotations carry too few visible keypoints.
+
+    Args:
+        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+        min_keypoints_per_image (int): an image is kept only if its visible
+            keypoints, summed over all its annotations, reach this number.
+
+    Returns:
+        list[dict]: the same format as dataset_dicts, but filtered.
+    """
+
+    def count_visible(record):
+        # "keypoints" is flat: [x1, y1, v1, x2, y2, v2, ...]; v > 0 means visible
+        total = 0
+        for ann in record["annotations"]:
+            if "keypoints" in ann:
+                total += (np.array(ann["keypoints"][2::3]) > 0).sum()
+        return total
+
+    # Build a new list; the input list itself is left untouched.
+    kept = [
+        rec for rec in dataset_dicts if count_visible(rec) >= min_keypoints_per_image
+    ]
+
+    logging.getLogger(__name__).info(
+        "Removed {} images with fewer than {} keypoints.".format(
+            len(dataset_dicts) - len(kept), min_keypoints_per_image
+        )
+    )
+    return kept
+
+
+def load_proposals_into_dataset(dataset_dicts, proposal_file):
+    """
+    Attach precomputed object proposals to the records of the dataset (in place).
+
+    The proposal file should be a pickled dict with the following keys:
+
+    - "ids": list[int] or list[str], the image ids
+    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
+    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
+      corresponding to the boxes.
+    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.
+
+    Args:
+        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
+        proposal_file (str): file path of pre-computed proposals, in pkl format.
+
+    Returns:
+        list[dict]: the same format as dataset_dicts, but added proposal field.
+    """
+    logging.getLogger(__name__).info("Loading proposals from: {}".format(proposal_file))
+    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
+    with PathManager.open(proposal_file, "rb") as f:
+        proposals = pickle.load(f, encoding="latin1")
+
+    # D1 proposal files use "indexes"/"scores"; move them to the current key names.
+    for old_key, new_key in (("indexes", "ids"), ("scores", "objectness_logits")):
+        if old_key in proposals:
+            proposals[new_key] = proposals.pop(old_key)
+
+    # Ids are compared as strings because either side may store them as int.
+    wanted = {str(record["image_id"]) for record in dataset_dicts}
+    row_of = {}
+    for row, image_id in enumerate(proposals["ids"]):
+        if str(image_id) in wanted:
+            row_of[str(image_id)] = row
+
+    # Files without an explicit "bbox_mode" are taken to be XYXY_ABS.
+    if "bbox_mode" in proposals:
+        bbox_mode = BoxMode(proposals["bbox_mode"])
+    else:
+        bbox_mode = BoxMode.XYXY_ABS
+
+    for record in dataset_dicts:
+        # raises KeyError when the file holds no proposals for this image
+        row = row_of[str(record["image_id"])]
+        boxes = proposals["boxes"][row]
+        logits = proposals["objectness_logits"][row]
+        order = logits.argsort()[::-1]  # highest objectness first
+        record["proposal_boxes"] = boxes[order]
+        record["proposal_objectness_logits"] = logits[order]
+        record["proposal_bbox_mode"] = bbox_mode
+    return dataset_dicts
+
+
+def print_instances_class_histogram(dataset_dicts, class_names):
+    """
+    Log a table with the number of non-crowd instances of every category.
+
+    Args:
+        dataset_dicts (list[dict]): list of dataset dicts.
+        class_names (list[str]): list of class names (zero-indexed).
+    """
+    num_classes = len(class_names)
+    hist_bins = np.arange(num_classes + 1)
+    histogram = np.zeros((num_classes,), dtype=int)  # np.int (a plain alias of int) was removed in NumPy 1.24
+    for entry in dataset_dicts:
+        annos = entry["annotations"]
+        classes = np.asarray(
+            [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=int
+        )
+        if len(classes):
+            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
+            assert (
+                classes.max() < num_classes
+            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
+        histogram += np.histogram(classes, bins=hist_bins)[0]
+
+    N_COLS = min(6, len(class_names) * 2)
+
+    def short_name(x):
+        # make long class names shorter. useful for lvis
+        return x[:11] + ".." if len(x) > 13 else x
+
+    data = list(
+        itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)])
+    )
+    total_num_instances = sum(data[1::2])
+    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
+    if num_classes > 1:
+        data.extend(["total", total_num_instances])
+    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
+    table = tabulate(
+        data,
+        headers=["category", "#instances"] * (N_COLS // 2),
+        tablefmt="pipe",
+        numalign="left",
+        stralign="center",
+    )
+    log_first_n(
+        logging.INFO,
+        "Distribution of instances among all {} categories:\n".format(num_classes)
+        + colored(table, "cyan"),
+        key="message",
+    )
+
+
+def get_detection_dataset_dicts(
+    names,
+    filter_empty=True,
+    min_keypoints=0,
+    proposal_files=None,
+    check_consistency=True,
+):
+    """
+    Fetch the named datasets and merge them into one list of dataset dicts for
+    instance detection/segmentation and semantic segmentation.
+
+    Args:
+        names (str or list[str]): a dataset name or a list of dataset names
+        filter_empty (bool): whether to filter out images without instance annotations
+        min_keypoints (int): filter out images with fewer keypoints than
+            `min_keypoints`. Set to 0 to do nothing.
+        proposal_files (list[str]): if given, a list of object proposal files
+            that match each dataset in `names`.
+        check_consistency (bool): whether to check if datasets have consistent metadata.
+
+    Returns:
+        list[dict]: a list of dicts following the standard dataset dict format.
+        When the datasets are torch ``Dataset`` objects, a ``Dataset`` is returned instead.
+    """
+    names = [names] if isinstance(names, str) else names
+    assert len(names), names
+    per_dataset = [DatasetCatalog.get(name) for name in names]
+
+    first = per_dataset[0]
+    if isinstance(first, torchdata.Dataset):
+        # ConcatDataset does not work for iterable style dataset; concatenating
+        # iterables is often not a good idea anyway.
+        return torchdata.ConcatDataset(per_dataset) if len(per_dataset) > 1 else first
+
+    for name, dicts in zip(names, per_dataset):
+        assert len(dicts), "Dataset '{}' is empty!".format(name)
+
+    if proposal_files is not None:
+        assert len(names) == len(proposal_files)
+        # attach each dataset's precomputed proposals to its records
+        per_dataset = [
+            load_proposals_into_dataset(dicts, path)
+            for dicts, path in zip(per_dataset, proposal_files)
+        ]
+
+    # flatten the per-dataset lists into a single list of records
+    merged = list(itertools.chain.from_iterable(per_dataset))
+
+    has_instances = "annotations" in merged[0]
+    if filter_empty and has_instances:
+        merged = filter_images_with_only_crowd_annotations(merged)
+    if min_keypoints > 0 and has_instances:
+        merged = filter_images_with_few_keypoints(merged, min_keypoints)
+
+    if check_consistency and has_instances:
+        try:
+            class_names = MetadataCatalog.get(names[0]).thing_classes
+            check_metadata_consistency("thing_classes", names)
+            print_instances_class_histogram(merged, class_names)
+        except AttributeError:  # class names are not available for this dataset
+            pass
+
+    assert len(merged), "No valid data found in {}.".format(",".join(names))
+    return merged
+
+
+def build_batch_data_loader(
+    dataset,
+    sampler,
+    total_batch_size,
+    *,
+    aspect_ratio_grouping=False,
+    num_workers=0,
+    collate_fn=None,
+):
+    """
+    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
+    1. support aspect ratio grouping options
+    2. use no "batch collation", because this is common for detection training
+
+    Args:
+        dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices.
+            Must be provided iff. ``dataset`` is a map-style dataset.
+        total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
+            :func:`build_detection_train_loader`.
+
+    Returns:
+        iterable[list]. Length of each list is the batch size of the current
+            GPU. Each element in the list comes from the dataset.
+    """
+    world_size = get_world_size()
+    assert (
+        total_batch_size > 0 and total_batch_size % world_size == 0
+    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
+        total_batch_size, world_size
+    )
+    batch_size = total_batch_size // world_size
+
+    if not isinstance(dataset, torchdata.IterableDataset):
+        dataset = ToIterableDataset(dataset, sampler)
+    else:
+        assert sampler is None, "sampler must be None if dataset is IterableDataset"
+
+    if not aspect_ratio_grouping:
+        return torchdata.DataLoader(
+            dataset,
+            batch_size=batch_size,
+            drop_last=True,
+            num_workers=num_workers,
+            collate_fn=collate_fn if collate_fn is not None else trivial_batch_collator,
+            worker_init_fn=worker_init_reset_seed,
+        )
+
+    # Grouping path: the DataLoader hands out one element at a time and
+    # AspectRatioGroupedDataset assembles the batches of `batch_size` itself.
+    single_items = torchdata.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
+        worker_init_fn=worker_init_reset_seed,
+    )
+    grouped = AspectRatioGroupedDataset(single_items, batch_size)
+    return grouped if collate_fn is None else MapDataset(grouped, collate_fn)
+
+
+def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
+    """Turn a config node into the keyword arguments of `build_detection_train_loader`."""
+    if dataset is None:
+        keypoint_on = cfg.MODEL.KEYPOINT_ON
+        dataset = get_detection_dataset_dicts(
+            cfg.DATASETS.TRAIN,
+            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE if keypoint_on else 0,
+            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
+        )
+        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
+
+    if mapper is None:
+        mapper = DatasetMapper(cfg, True)
+
+    if sampler is None:
+        logger = logging.getLogger(__name__)
+        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
+        if isinstance(dataset, torchdata.IterableDataset):
+            logger.info("Not using any sampler since the dataset is IterableDataset.")
+        else:
+            logger.info("Using training sampler {}".format(sampler_name))
+            if sampler_name == "TrainingSampler":
+                sampler = TrainingSampler(len(dataset))
+            elif sampler_name == "RandomSubsetTrainingSampler":
+                sampler = RandomSubsetTrainingSampler(
+                    len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
+                )
+            elif sampler_name == "RepeatFactorTrainingSampler":
+                sampler = RepeatFactorTrainingSampler(
+                    RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                        dataset, cfg.DATALOADER.REPEAT_THRESHOLD
+                    )
+                )
+            else:
+                raise ValueError("Unknown training sampler: {}".format(sampler_name))
+
+    return dict(
+        dataset=dataset,
+        sampler=sampler,
+        mapper=mapper,
+        total_batch_size=cfg.SOLVER.IMS_PER_BATCH,
+        aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING,
+        num_workers=cfg.DATALOADER.NUM_WORKERS,
+    )
+
+
+@configurable(from_config=_train_loader_from_config)
+def build_detection_train_loader(
+    dataset,
+    *,
+    mapper,
+    sampler=None,
+    total_batch_size,
+    aspect_ratio_grouping=True,
+    num_workers=0,
+    collate_fn=None,
+):
+    """
+    Build a dataloader for object detection with some default features.
+
+    Args:
+        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+            or a pytorch dataset (either map-style or iterable). It can be obtained
+            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        mapper (callable): a callable which takes a sample (dict) from dataset and
+            returns the format to be consumed by the model. Pass None to skip mapping.
+            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``.
+            If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`,
+            which coordinates an infinite random shuffle sequence across all workers.
+            Sampler must be None if ``dataset`` is iterable.
+        total_batch_size (int): total batch size across all workers.
+        aspect_ratio_grouping (bool): whether to group images with similar
+            aspect ratio for efficiency. When enabled, it requires each
+            element in dataset be a dict with keys "width" and "height".
+        num_workers (int): number of parallel data loading workers
+        collate_fn: a function that determines how to do batching, same as the argument of
+            `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of
+            data. No collation is OK for small batch size and simple data structures.
+            If your batch size is large and each sample contains too many small tensors,
+            it's more efficient to collate them in data loader.
+
+    Returns:
+        torch.utils.data.DataLoader:
+            a dataloader. Each output from it is a ``list[mapped_element]`` whose length
+            is ``total_batch_size`` divided by the world size, where ``mapped_element``
+            is produced by the ``mapper``.
+    """
+    if isinstance(dataset, list):
+        dataset = DatasetFromList(dataset, copy=False)
+    if mapper is not None:
+        dataset = MapDataset(dataset, mapper)
+
+    if not isinstance(dataset, torchdata.IterableDataset):
+        sampler = TrainingSampler(len(dataset)) if sampler is None else sampler
+        assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}"
+    else:
+        assert sampler is None, "sampler must be None if dataset is IterableDataset"
+    # batching, grouping and worker seeding are handled by the generic helper
+    return build_batch_data_loader(
+        dataset,
+        sampler,
+        total_batch_size,
+        aspect_ratio_grouping=aspect_ratio_grouping,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+    )
+
+
+def _test_loader_from_config(cfg, dataset_name, mapper=None):
+    """
+    Build the loader arguments from `dataset_name` rather than from the names in cfg:
+    the standard practice is to evaluate each test set individually (not combining them).
+    """
+    names = [dataset_name] if isinstance(dataset_name, str) else dataset_name
+
+    proposal_files = None
+    if cfg.MODEL.LOAD_PROPOSALS:
+        # for every requested dataset, take the proposal file found at the position
+        # that the dataset occupies in cfg.DATASETS.TEST
+        test_names = list(cfg.DATASETS.TEST)
+        proposal_files = [cfg.DATASETS.PROPOSAL_FILES_TEST[test_names.index(n)] for n in names]
+    dataset = get_detection_dataset_dicts(names, filter_empty=False, proposal_files=proposal_files)
+
+    if mapper is None:
+        mapper = DatasetMapper(cfg, False)
+    # an iterable dataset gets no sampler
+    sampler = None
+    if not isinstance(dataset, torchdata.IterableDataset):
+        sampler = InferenceSampler(len(dataset))
+    return {
+        "dataset": dataset,
+        "mapper": mapper,
+        "num_workers": cfg.DATALOADER.NUM_WORKERS,
+        "sampler": sampler,
+    }
+
+
+@configurable(from_config=_test_loader_from_config)
+def build_detection_test_loader(
+    dataset: Union[List[Any], torchdata.Dataset],
+    *,
+    mapper: Callable[[Dict[str, Any]], Any],
+    sampler: Optional[torchdata.Sampler] = None,
+    batch_size: int = 1,
+    num_workers: int = 0,
+    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
+) -> torchdata.DataLoader:
+    """
+    Similar to `build_detection_train_loader`, with default batch size = 1,
+    and sampler = :class:`InferenceSampler`. This sampler coordinates all workers
+    to produce the exact set of all samples.
+
+    Args:
+        dataset: a list of dataset dicts,
+            or a pytorch dataset (either map-style or iterable). They can be obtained
+            by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        mapper: a callable which takes a sample (dict) from dataset
+            and returns the format to be consumed by the model.
+            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
+        sampler: a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
+            which splits the dataset across all workers. Sampler must be None
+            if `dataset` is iterable.
+        batch_size: the batch size of the data loader to be created.
+            Default to 1 image per worker since this is the standard when reporting
+            inference time in papers.
+        num_workers: number of parallel data loading workers
+        collate_fn: same as the argument of `torch.utils.data.DataLoader`.
+            Defaults to do no collation and return a list of data.
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given detection
+        dataset, with test-time transformation and batching.
+
+    Examples:
+    ::
+        data_loader = build_detection_test_loader(
+            DatasetCatalog.get("my_test"),
+            mapper=DatasetMapper(...))
+
+        # or, instantiate with a CfgNode:
+        data_loader = build_detection_test_loader(cfg, "my_test")
+    """
+    if isinstance(dataset, list):
+        dataset = DatasetFromList(dataset, copy=False)
+    if mapper is not None:
+        dataset = MapDataset(dataset, mapper)
+    if not isinstance(dataset, torchdata.IterableDataset):
+        sampler = InferenceSampler(len(dataset)) if sampler is None else sampler
+    else:
+        assert sampler is None, "sampler must be None if dataset is IterableDataset"
+    # drop_last=False keeps the final, possibly smaller, batch
+    return torchdata.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        drop_last=False,
+        num_workers=num_workers,
+        collate_fn=collate_fn if collate_fn is not None else trivial_batch_collator,
+    )
+
+
+def trivial_batch_collator(batch):
+    """
+    Identity collate function: hand the batch back exactly as it was received.
+    """
+    return batch
+
+
+def worker_init_reset_seed(worker_id):
+    # fold torch's initial seed into 31 bits, then offset it by the worker id
+    seed_all_rng(torch.initial_seed() % 2**31 + worker_id)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/catalog.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/catalog.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f5209b5583d01258437bdc9b52a3dd716bdbbf6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/catalog.py
@@ -0,0 +1,236 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import types
+from collections import UserDict
+from typing import List
+
+from annotator.oneformer.detectron2.utils.logger import log_first_n
+
+__all__ = ["DatasetCatalog", "MetadataCatalog", "Metadata"]
+
+
+class _DatasetCatalog(UserDict):
+    """
+    A global dictionary that stores information about the datasets and how to obtain them.
+
+    It maps a string
+    (the name that identifies a dataset, e.g. "coco_2014_train")
+    to a function that parses the dataset and returns its samples in the
+    format of `list[dict]`.
+
+    The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
+    if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
+
+    Thanks to this catalog, a dataset can be selected simply by writing its
+    name in the config.
+    """
+
+    def register(self, name, func):
+        """
+        Args:
+            name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+                Registering a name that is already present fails an assertion.
+            func (callable): takes no arguments and returns a list of dicts; it must
+                return the same results if called multiple times.
+        """
+        assert callable(func), "You must register a function with `DatasetCatalog.register`!"
+        assert name not in self, "Dataset '{}' is already registered!".format(name)
+        self[name] = func
+
+    def get(self, name):
+        """
+        Look up the function registered under `name`, call it, and return its result.
+
+        Args:
+            name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+
+        Returns:
+            list[dict]: dataset annotations.
+        """
+        try:
+            loader = self[name]
+        except KeyError as err:
+            known = ", ".join(list(self.keys()))
+            message = "Dataset '{}' is not registered! Available datasets are: {}".format(
+                name, known
+            )
+            raise KeyError(message) from err
+        return loader()
+
+    def list(self) -> List[str]:
+        """
+        Names of all registered datasets.
+
+        Returns:
+            list[str]
+        """
+        return [*self.keys()]
+
+    def remove(self, name):
+        """
+        Alias of ``pop``.
+        """
+        self.pop(name)
+
+    def __str__(self):
+        return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys()))
+
+    __repr__ = __str__
+
+
+DatasetCatalog = _DatasetCatalog()  # module-level instance; this is the object exported as `DatasetCatalog`
+DatasetCatalog.__doc__ = (  # instance docstring = class docstring + Sphinx automethod entries
+    _DatasetCatalog.__doc__
+    + """
+    .. automethod:: detectron2.data.catalog.DatasetCatalog.register
+    .. automethod:: detectron2.data.catalog.DatasetCatalog.get
+"""
+)
+
+
+class Metadata(types.SimpleNamespace):
+    """
+    A namespace with simple attribute setter/getter and a write-once rule per key.
+    It is intended for storing metadata of a dataset and make it accessible globally.
+
+    Examples:
+    ::
+        # somewhere when you load the data:
+        MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"]
+
+        # somewhere when you print statistics or visualize:
+        classes = MetadataCatalog.get("mydataset").thing_classes
+    """
+
+    # the name of the dataset
+    # set default to N/A so that `self.name` in the errors will not trigger getattr again
+    name: str = "N/A"
+
+    # deprecated attribute name -> current attribute name
+    _RENAMED = {
+        "class_names": "thing_classes",
+        "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id",
+        "stuff_class_names": "stuff_classes",
+    }
+
+    def __getattr__(self, key):
+        if key in self._RENAMED:
+            log_first_n(
+                logging.WARNING,
+                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
+                n=10,
+            )
+            return getattr(self, self._RENAMED[key])
+
+        # "name" exists in every metadata
+        if len(self.__dict__) > 1:
+            raise AttributeError(
+                "Attribute '{}' does not exist in the metadata of dataset '{}'. Available "
+                "keys are {}.".format(key, self.name, str(self.__dict__.keys()))
+            )
+        raise AttributeError(
+            f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': "
+            "metadata is empty."
+        )
+
+    def __setattr__(self, key, val):
+        if key in self._RENAMED:
+            log_first_n(
+                logging.WARNING,
+                "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]),
+                n=10,
+            )
+            # The nested call stores (and consistency-checks) the value under the new name.
+            # Stop here: falling through re-read the old name and logged the warning twice.
+            setattr(self, self._RENAMED[key], val)
+            return
+
+        # Ensure that metadata of the same name stays consistent
+        try:
+            oldval = getattr(self, key)
+            assert oldval == val, (
+                "Attribute '{}' in the metadata of '{}' cannot be set "
+                "to a different value!\n{} != {}".format(key, self.name, oldval, val)
+            )
+        except AttributeError:
+            super().__setattr__(key, val)
+
+    def as_dict(self):
+        """
+        Returns all the metadata as a dict.
+        Note that modifications to the returned dict will not reflect on the Metadata object.
+        """
+        return copy.copy(self.__dict__)
+
+    def set(self, **kwargs):
+        """
+        Set multiple metadata with kwargs. Returns this object, so calls can be chained.
+        """
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        return self
+
+    def get(self, key, default=None):
+        """
+        Access an attribute and return its value if exists.
+        Otherwise return default.
+        """
+        return getattr(self, key, default)
+
+
+class _MetadataCatalog(UserDict):
+    """
+    MetadataCatalog is a global dictionary that gives access to the
+    :class:`Metadata` of a given dataset.
+
+    The metadata stored under a name is a singleton: once created, it stays
+    alive and the same object is returned by future calls to ``get(name)``.
+
+    It's like global variables, so don't abuse it.
+    It's meant for storing knowledge that's constant and shared across the execution
+    of the program, e.g.: the class names in COCO.
+    """
+
+    def get(self, name):
+        """
+        Args:
+            name (str): name of a dataset (e.g. coco_2014_train). Must be non-empty.
+
+        Returns:
+            Metadata: the instance stored under this name; a new, empty one is
+            created and stored first if the name is not known yet.
+        """
+        assert len(name)
+        existing = super().get(name, None)
+        if existing is not None:
+            return existing
+        created = self[name] = Metadata(name=name)
+        return created
+
+    def list(self):
+        """
+        List all registered metadata.
+        Returns:
+            list[str]: keys (names of datasets) of all registered metadata
+        """
+        return [*self.keys()]
+
+    def remove(self, name):
+        """
+        Alias of ``pop``.
+        """
+        self.pop(name)
+
+    def __str__(self):
+        return "MetadataCatalog(registered metadata: {})".format(", ".join(self.keys()))
+
+    __repr__ = __str__
+
+
+MetadataCatalog = _MetadataCatalog()  # module-level instance; this is the object exported as `MetadataCatalog`
+MetadataCatalog.__doc__ = (  # instance docstring = class docstring + a Sphinx automethod entry
+    _MetadataCatalog.__doc__
+    + """
+    .. automethod:: detectron2.data.catalog.MetadataCatalog.get
+"""
+)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/common.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa69a6a6546030aee818b195a0fbb399d5b776f6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/common.py
@@ -0,0 +1,301 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import copy
+import itertools
+import logging
+import numpy as np
+import pickle
+import random
+from typing import Callable, Union
+import torch
+import torch.utils.data as data
+from torch.utils.data.sampler import Sampler
+
+from annotator.oneformer.detectron2.utils.serialize import PicklableWrapper
+
+__all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]
+
+logger = logging.getLogger(__name__)
+
+
+def _shard_iterator_dataloader_worker(iterable):
+    """
+    Yield the elements of ``iterable``, sharded across pytorch dataloader workers.
+
+    Outside of a dataloader worker, or when there is only one worker, every
+    element is yielded. Otherwise the worker with id ``i`` yields elements
+    ``i``, ``i + num_workers``, ``i + 2 * num_workers``, ...
+    """
+    info = data.get_worker_info()
+    needs_sharding = info is not None and info.num_workers != 1
+    if needs_sharding:
+        iterable = itertools.islice(iterable, info.id, None, info.num_workers)
+    yield from iterable
+
+
+class _MapIterableDataset(data.IterableDataset):
+    """
+    Apply a function to every element of an IterableDataset.
+
+    Comparable to pytorch's MapIterDataPipe, with one addition: elements for
+    which map_func returns None are filtered out.
+
+    Internal helper, not public-facing; `MapDataset` constructs it.
+    """
+
+    def __init__(self, dataset, map_func):
+        self._dataset = dataset
+        # wrap so that a lambda will work
+        self._map_func = PicklableWrapper(map_func)
+
+    def __len__(self):
+        # Length of the wrapped dataset, i.e. counted before any filtering.
+        return len(self._dataset)
+
+    def __iter__(self):
+        for element in self._dataset:
+            mapped = self._map_func(element)
+            if mapped is None:
+                continue
+            yield mapped
+
+
+class MapDataset(data.Dataset):
+    """
+    Map a function over the elements in a dataset.
+    """
+
+    def __init__(self, dataset, map_func):
+        """
+        Args:
+            dataset: a dataset where map function is applied. Can be either
+                map-style or iterable dataset. When given an iterable dataset,
+                the returned object will also be an iterable dataset.
+            map_func: a callable which maps the element in dataset. map_func can
+                return None to skip the data (e.g. in case of errors).
+                How None is handled depends on the style of `dataset`.
+                If `dataset` is map-style, it randomly tries other elements.
+                If `dataset` is iterable, it skips the data and tries the next.
+        """
+        self._dataset = dataset
+        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work
+
+        # Fixed seed, so the sequence of fallback draws is reproducible.
+        self._rng = random.Random(42)
+        # Indices that may be drawn as a replacement when map_func returns None.
+        self._fallback_candidates = set(range(len(dataset)))
+
+    def __new__(cls, dataset, map_func):
+        # An iterable dataset is delegated to _MapIterableDataset; since that
+        # object is not an instance of this class, __init__ above is not run for it.
+        is_iterable = isinstance(dataset, data.IterableDataset)
+        if is_iterable:
+            return _MapIterableDataset(dataset, map_func)
+        else:
+            return super().__new__(cls)
+
+    def __getnewargs__(self):
+        # Arguments handed to __new__ when the object is unpickled.
+        return self._dataset, self._map_func
+
+    def __len__(self):
+        return len(self._dataset)
+
+    def __getitem__(self, idx):
+        """
+        Return the mapped element at ``idx``; when map_func returns None for it,
+        keep drawing random replacement indices until one maps successfully.
+
+        Raises:
+            ValueError: raised by ``random.sample`` when no fallback candidate
+                is left to draw from.
+        """
+        retry_count = 0
+        cur_idx = int(idx)
+
+        while True:
+            mapped = self._map_func(self._dataset[cur_idx])
+            if mapped is not None:
+                self._fallback_candidates.add(cur_idx)
+                return mapped
+
+            # _map_func fails for this idx, use a random new index from the pool
+            retry_count += 1
+            self._fallback_candidates.discard(cur_idx)
+            # random.sample() rejects sets since Python 3.11 (TypeError). Older
+            # versions converted a set with tuple() internally, so doing the
+            # same conversion here keeps their behavior and works everywhere.
+            cur_idx = self._rng.sample(tuple(self._fallback_candidates), k=1)[0]
+
+            if retry_count >= 3:
+                logging.getLogger(__name__).warning(
+                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
+                        idx, retry_count
+                    )
+                )
+
+
+class _TorchSerializedList(object):
+    """
+    A list-like object that keeps its items pickled inside one torch tensor.
+
+    When a process using TorchSerializedList is launched with the "fork" start
+    method, the subprocess can read the same buffer without triggering
+    copy-on-access. With the "spawn/forkserver" start method, the list is pickled
+    by a special ForkingPickler registered by PyTorch that moves data to shared
+    memory. Either way parent and child processes share RAM for the list data,
+    which avoids the issue in https://github.com/pytorch/pytorch/issues/13246.
+
+    See also https://ppwwyyxx.com/blog/2022/Demystify-RAM-Usage-in-Multiprocess-DataLoader/
+    on how it works.
+    """
+
+    def __init__(self, lst: list):
+        def _to_uint8(obj):
+            # Pickle with the highest protocol and view the bytes as a uint8 array.
+            return np.frombuffer(pickle.dumps(obj, protocol=-1), dtype=np.uint8)
+
+        logger.info(
+            f"Serializing {len(lst)} elements to byte tensors and concatenating them all ..."
+        )
+        chunks = [_to_uint8(item) for item in lst]
+        # _addr[i] is the end offset of item i inside the concatenated buffer.
+        end_offsets = np.cumsum(np.asarray([len(chunk) for chunk in chunks], dtype=np.int64))
+        self._addr = torch.from_numpy(end_offsets)
+        self._lst = torch.from_numpy(np.concatenate(chunks))
+        logger.info(f"Serialized dataset takes {len(self._lst) / 1024**2:.2f} MiB")
+
+    def __len__(self):
+        return len(self._addr)
+
+    def __getitem__(self, idx):
+        # Item idx occupies [end offset of the previous item, its own end offset).
+        begin = self._addr[idx - 1].item() if idx != 0 else 0
+        end = self._addr[idx].item()
+        payload = memoryview(self._lst[begin:end].numpy())
+
+        # @lint-ignore PYTHONPICKLEISBAD
+        return pickle.loads(payload)
+
+
+_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = _TorchSerializedList
+
+
+@contextlib.contextmanager
+def set_default_dataset_from_list_serialize_method(new):
+    """
+    Context manager for using custom serialize function when creating DatasetFromList
+
+    Args:
+        new (callable): serialize method installed as the module-level default
+            for the duration of the ``with`` body.
+
+    The previous default is restored on exit, including when the body raises.
+    """
+
+    global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+    orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+    _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new
+    try:
+        yield
+    finally:
+        # Without try/finally an exception raised inside the ``with`` body would
+        # skip this line and leave ``new`` installed as the default afterwards.
+        _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig
+
+
+class DatasetFromList(data.Dataset):
+    """
+    Wrap a list to a torch Dataset. It produces elements of the list as data.
+    """
+
+    def __init__(
+        self,
+        lst: list,
+        copy: bool = True,
+        serialize: Union[bool, Callable] = True,
+    ):
+        """
+        Args:
+            lst (list): a list which contains elements to produce.
+            copy (bool): whether to deepcopy the element when producing it,
+                so that the result can be modified in place without affecting the
+                source in the list.
+            serialize (bool or callable): whether to serialize the storage to other
+                backend. If `True`, the default serialize method will be used, if given
+                a callable, the callable will be used as serialize method.
+
+        Raises:
+            TypeError: if `serialize` is neither a bool nor a callable.
+        """
+        self._lst = lst
+        self._copy = copy
+        if not isinstance(serialize, (bool, Callable)):
+            raise TypeError(f"Unsupported type for argument `serialize`: {serialize}")
+        # Anything except an explicit False (i.e. True or a callable) enables serialization.
+        self._serialize = serialize is not False
+
+        if self._serialize:
+            serialize_method = (
+                serialize
+                if isinstance(serialize, Callable)
+                else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD
+            )
+            logger.info(f"Serializing the dataset using: {serialize_method}")
+            self._lst = serialize_method(self._lst)
+
+    def __len__(self):
+        return len(self._lst)
+
+    def __getitem__(self, idx):
+        # The deepcopy is applied only to the plain, unserialized list; with
+        # serialization enabled the stored object's own __getitem__ result is
+        # returned as is.
+        if self._copy and not self._serialize:
+            return copy.deepcopy(self._lst[idx])
+        else:
+            return self._lst[idx]
+
+
+class ToIterableDataset(data.IterableDataset):
+    """
+    Adapter that turns an old indices-based (also called map-style) dataset
+    into an iterable-style dataset by driving it with a sampler.
+    """
+
+    def __init__(self, dataset: data.Dataset, sampler: Sampler, shard_sampler: bool = True):
+        """
+        Args:
+            dataset: an old-style dataset with ``__getitem__``
+            sampler: a cheap iterable that produces indices to be applied on ``dataset``.
+            shard_sampler: whether to shard the sampler based on the current pytorch data loader
+                worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple
+                workers, it is responsible for sharding its data based on worker id so that workers
+                don't produce identical data.
+
+                Most samplers (like our TrainingSampler) do not shard based on dataloader worker id
+                and this argument should be set to True. But certain samplers may be already
+                sharded, in that case this argument should be set to False.
+        """
+        assert not isinstance(dataset, data.IterableDataset), dataset
+        assert isinstance(sampler, Sampler), sampler
+        self.dataset, self.sampler, self.shard_sampler = dataset, sampler, shard_sampler
+
+    def __iter__(self):
+        # With map-style dataset, `DataLoader(dataset, sampler)` runs the
+        # sampler in main process only. But `DataLoader(ToIterableDataset(dataset, sampler))`
+        # will run sampler in every of the N worker. So we should only keep 1/N of the ids on
+        # each worker. The assumption is that sampler is cheap to iterate so it's fine to
+        # discard ids in workers.
+        indices = (
+            _shard_iterator_dataloader_worker(self.sampler)
+            if self.shard_sampler
+            else self.sampler
+        )
+        for index in indices:
+            yield self.dataset[index]
+
+    def __len__(self):
+        return len(self.sampler)
+
+
+class AspectRatioGroupedDataset(data.IterableDataset):
+    """
+    Group data with a similar aspect ratio into the same batch.
+    In this implementation there are two groups: elements with width > height,
+    and all remaining elements.
+    This improves training speed because the images then need less padding
+    to form a batch.
+
+    The underlying dataset is assumed to produce dicts with "width" and "height" keys.
+    Each produced item is a list of ``batch_size`` of those original dicts,
+    all taken from the same group.
+    """
+
+    def __init__(self, dataset, batch_size):
+        """
+        Args:
+            dataset: an iterable. Each element must be a dict with keys
+                "width" and "height", which will be used to batch data.
+            batch_size (int):
+        """
+        self.dataset = dataset
+        self.batch_size = batch_size
+        # Hard-coded two aspect ratio groups: w > h and w < h.
+        # Can add support for more aspect ratio groups, but doesn't seem useful
+        self._buckets = [[], []]
+
+    def __iter__(self):
+        for sample in self.dataset:
+            width, height = sample["width"], sample["height"]
+            pending = self._buckets[0] if width > height else self._buckets[1]
+            pending.append(sample)
+            if len(pending) != self.batch_size:
+                continue
+            batch = list(pending)
+            # Clear bucket first, because code after yield is not
+            # guaranteed to execute
+            pending.clear()
+            yield batch
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/dataset_mapper.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb6bb1057a68bfb12e55872f391065f02023ed3
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/dataset_mapper.py
@@ -0,0 +1,191 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import numpy as np
+from typing import List, Optional, Union
+import torch
+
+from annotator.oneformer.detectron2.config import configurable
+
+from . import detection_utils as utils
+from . import transforms as T
+
+"""
+This file contains the default mapping that's applied to "dataset dicts".
+"""
+
+__all__ = ["DatasetMapper"]
+
+
+class DatasetMapper:
+    """
+    A callable which takes a dataset dict in Detectron2 Dataset format,
+    and map it into a format used by the model.
+
+    This is the default callable to be used to map your dataset dict into training data.
+    You may need to follow it to implement your own one for customized logic,
+    such as a different way to read or transform images.
+    See :doc:`/tutorials/data_loading` for details.
+
+    The callable currently does the following:
+
+    1. Read the image from "file_name"
+    2. Applies cropping/geometric transforms to the image and annotations
+    3. Prepare data and annotations to Tensor and :class:`Instances`
+    """
+
+    @configurable
+    def __init__(
+        self,
+        is_train: bool,
+        *,
+        augmentations: List[Union[T.Augmentation, T.Transform]],
+        image_format: str,
+        use_instance_mask: bool = False,
+        use_keypoint: bool = False,
+        instance_mask_format: str = "polygon",
+        keypoint_hflip_indices: Optional[np.ndarray] = None,
+        precomputed_proposal_topk: Optional[int] = None,
+        recompute_boxes: bool = False,
+    ):
+        """
+        NOTE: this interface is experimental.
+
+        Args:
+            is_train: whether it's used in training or inference
+            augmentations: a list of augmentations or deterministic transforms to apply
+            image_format: an image format supported by :func:`detection_utils.read_image`.
+            use_instance_mask: whether to process instance segmentation annotations, if available
+            use_keypoint: whether to process keypoint annotations if available
+            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
+                masks into this format.
+            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
+            precomputed_proposal_topk: if given, will load pre-computed
+                proposals from dataset_dict and keep the top k proposals for each image.
+            recompute_boxes: whether to overwrite bounding box annotations
+                by computing tight bounding boxes from instance mask annotations.
+        """
+        if recompute_boxes:
+            # Boxes are recomputed from gt_masks, so masks have to be processed.
+            assert use_instance_mask, "recompute_boxes requires instance masks"
+        # fmt: off
+        self.is_train = is_train
+        self.augmentations = T.AugmentationList(augmentations)
+        self.image_format = image_format
+        self.use_instance_mask = use_instance_mask
+        self.instance_mask_format = instance_mask_format
+        self.use_keypoint = use_keypoint
+        self.keypoint_hflip_indices = keypoint_hflip_indices
+        self.proposal_topk = precomputed_proposal_topk
+        self.recompute_boxes = recompute_boxes
+        # fmt: on
+        logger = logging.getLogger(__name__)
+        mode = "training" if is_train else "inference"
+        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
+
+    @classmethod
+    def from_config(cls, cfg, is_train: bool = True):
+        """
+        Build a dict of ``__init__`` keyword arguments from ``cfg``.
+
+        Args:
+            cfg: config object; the INPUT, MODEL and DATASETS entries read below are used.
+            is_train: whether the mapper is built for training or inference.
+
+        Returns:
+            dict: keys named after the parameters of ``__init__``.
+            NOTE(review): presumably consumed by the ``@configurable`` decorator on
+            ``__init__``; confirm against its implementation.
+        """
+        augs = utils.build_augmentation(cfg, is_train)
+        if cfg.INPUT.CROP.ENABLED and is_train:
+            # The crop is placed in front of all other augmentations.
+            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
+            # With cropping enabled, boxes are recomputed whenever masks are in use.
+            recompute_boxes = cfg.MODEL.MASK_ON
+        else:
+            recompute_boxes = False
+
+        ret = {
+            "is_train": is_train,
+            "augmentations": augs,
+            "image_format": cfg.INPUT.FORMAT,
+            "use_instance_mask": cfg.MODEL.MASK_ON,
+            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
+            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
+            "recompute_boxes": recompute_boxes,
+        }
+
+        if cfg.MODEL.KEYPOINT_ON:
+            # The flip indices are always derived from the training datasets,
+            # regardless of ``is_train``.
+            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
+
+        if cfg.MODEL.LOAD_PROPOSALS:
+            ret["precomputed_proposal_topk"] = (
+                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
+                if is_train
+                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
+            )
+        return ret
+
+    def _transform_annotations(self, dataset_dict, transforms, image_shape):
+        """
+        Convert ``dataset_dict["annotations"]`` into ``dataset_dict["instances"]``.
+
+        ``dataset_dict`` is modified in place: the "annotations" entry is popped and
+        an "instances" entry is stored.
+
+        Args:
+            dataset_dict (dict): must contain an "annotations" list.
+            transforms: the transforms passed on to
+                :func:`detection_utils.transform_instance_annotations`.
+            image_shape: (h, w) of the transformed image.
+        """
+        # USER: Modify this if you want to keep them for some reason.
+        for anno in dataset_dict["annotations"]:
+            if not self.use_instance_mask:
+                anno.pop("segmentation", None)
+            if not self.use_keypoint:
+                anno.pop("keypoints", None)
+
+        # USER: Implement additional transformations if you have other types of data
+        # Annotations with a non-zero "iscrowd" are dropped; a missing key counts as 0.
+        annos = [
+            utils.transform_instance_annotations(
+                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
+            )
+            for obj in dataset_dict.pop("annotations")
+            if obj.get("iscrowd", 0) == 0
+        ]
+        instances = utils.annotations_to_instances(
+            annos, image_shape, mask_format=self.instance_mask_format
+        )
+
+        # After transforms such as cropping are applied, the bounding box may no longer
+        # tightly bound the object. As an example, imagine a triangle object
+        # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
+        # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
+        # the intersection of original bounding box and the cropping box.
+        if self.recompute_boxes:
+            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
+        dataset_dict["instances"] = utils.filter_empty_instances(instances)
+
+    def __call__(self, dataset_dict):
+        """
+        Args:
+            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+        Returns:
+            dict: a format that builtin models in detectron2 accept
+        """
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        # USER: Write your own image loading if it's not from a file
+        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
+        utils.check_image_size(dataset_dict, image)
+
+        # USER: Remove if you don't do semantic/panoptic segmentation.
+        if "sem_seg_file_name" in dataset_dict:
+            # Read in "L" format and drop the last axis with squeeze(2).
+            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
+        else:
+            sem_seg_gt = None
+
+        # The augmentations update ``aug_input`` in place and return the transforms
+        # that are reused for proposals and annotations below.
+        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+        transforms = self.augmentations(aug_input)
+        image, sem_seg_gt = aug_input.image, aug_input.sem_seg
+
+        image_shape = image.shape[:2]  # h, w
+        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+        # Therefore it's important to use torch.Tensor.
+        # transpose(2, 0, 1) moves the last image axis to the front before conversion.
+        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+        if sem_seg_gt is not None:
+            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
+
+        # USER: Remove if you don't use pre-computed proposals.
+        # Most users would not need this feature.
+        if self.proposal_topk is not None:
+            utils.transform_proposals(
+                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
+            )
+
+        if not self.is_train:
+            # USER: Modify this if you want to keep them for some reason.
+            dataset_dict.pop("annotations", None)
+            dataset_dict.pop("sem_seg_file_name", None)
+            return dataset_dict
+
+        if "annotations" in dataset_dict:
+            self._transform_annotations(dataset_dict, transforms, image_shape)
+
+        return dataset_dict
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/README.md b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9fb3e4f7afec17137c95c78be6ef06d520ec8032
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/README.md
@@ -0,0 +1,9 @@
+
+
+### Common Datasets
+
+The datasets implemented here do not need to load the data into the final format.
+They should provide the minimal data structure needed to use the dataset, so they can be very efficient.
+
+For example, for an image dataset, just provide the file names and labels, but don't read the images.
+Let the downstream decide how to read.
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44bedc15e5f0e762fc4d77efd6f1b07c6ff77d0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json
+from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
+from .lvis import load_lvis_json, register_lvis_instances, get_lvis_instances_meta
+from .pascal_voc import load_voc_instances, register_pascal_voc
+from . import builtin as _builtin # ensure the builtin datasets are registered
+
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/builtin.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/builtin.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bbb1feec64f76705ba32c46f19f89f71be2ca7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/builtin.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+
+"""
+This file registers pre-defined datasets at hard-coded paths, and their metadata.
+
+We hard-code metadata for common datasets. This will enable:
+1. Consistency check when loading the datasets
+2. Use models on these standard datasets directly and run demos,
+ without having to download the dataset annotations
+
+We hard-code some paths to the dataset that's assumed to
+exist in "./datasets/".
+
+Users SHOULD NOT use this file to create new dataset / metadata for new dataset.
+To add new dataset, refer to the tutorial "docs/DATASETS.md".
+"""
+
+import os
+
+from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
+
+from .builtin_meta import ADE20K_SEM_SEG_CATEGORIES, _get_builtin_metadata
+from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic
+from .cityscapes_panoptic import register_all_cityscapes_panoptic
+from .coco import load_sem_seg, register_coco_instances
+from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
+from .lvis import get_lvis_instances_meta, register_lvis_instances
+from .pascal_voc import register_pascal_voc
+
+# ==== Predefined datasets and splits for COCO ==========
+
+_PREDEFINED_SPLITS_COCO = {}
+_PREDEFINED_SPLITS_COCO["coco"] = {
+ "coco_2014_train": ("coco/train2014", "coco/annotations/instances_train2014.json"),
+ "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"),
+ "coco_2014_minival": ("coco/val2014", "coco/annotations/instances_minival2014.json"),
+ "coco_2014_valminusminival": (
+ "coco/val2014",
+ "coco/annotations/instances_valminusminival2014.json",
+ ),
+ "coco_2017_train": ("coco/train2017", "coco/annotations/instances_train2017.json"),
+ "coco_2017_val": ("coco/val2017", "coco/annotations/instances_val2017.json"),
+ "coco_2017_test": ("coco/test2017", "coco/annotations/image_info_test2017.json"),
+ "coco_2017_test-dev": ("coco/test2017", "coco/annotations/image_info_test-dev2017.json"),
+ "coco_2017_val_100": ("coco/val2017", "coco/annotations/instances_val2017_100.json"),
+}
+
+_PREDEFINED_SPLITS_COCO["coco_person"] = {
+ "keypoints_coco_2014_train": (
+ "coco/train2014",
+ "coco/annotations/person_keypoints_train2014.json",
+ ),
+ "keypoints_coco_2014_val": ("coco/val2014", "coco/annotations/person_keypoints_val2014.json"),
+ "keypoints_coco_2014_minival": (
+ "coco/val2014",
+ "coco/annotations/person_keypoints_minival2014.json",
+ ),
+ "keypoints_coco_2014_valminusminival": (
+ "coco/val2014",
+ "coco/annotations/person_keypoints_valminusminival2014.json",
+ ),
+ "keypoints_coco_2017_train": (
+ "coco/train2017",
+ "coco/annotations/person_keypoints_train2017.json",
+ ),
+ "keypoints_coco_2017_val": ("coco/val2017", "coco/annotations/person_keypoints_val2017.json"),
+ "keypoints_coco_2017_val_100": (
+ "coco/val2017",
+ "coco/annotations/person_keypoints_val2017_100.json",
+ ),
+}
+
+
+_PREDEFINED_SPLITS_COCO_PANOPTIC = {
+ "coco_2017_train_panoptic": (
+ # This is the original panoptic annotation directory
+ "coco/panoptic_train2017",
+ "coco/annotations/panoptic_train2017.json",
+ # This directory contains semantic annotations that are
+ # converted from panoptic annotations.
+ # It is used by PanopticFPN.
+ # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py
+ # to create these directories.
+ "coco/panoptic_stuff_train2017",
+ ),
+ "coco_2017_val_panoptic": (
+ "coco/panoptic_val2017",
+ "coco/annotations/panoptic_val2017.json",
+ "coco/panoptic_stuff_val2017",
+ ),
+ "coco_2017_val_100_panoptic": (
+ "coco/panoptic_val2017_100",
+ "coco/annotations/panoptic_val2017_100.json",
+ "coco/panoptic_stuff_val2017_100",
+ ),
+}
+
+
+def register_all_coco(root):
+    """
+    Register every predefined COCO instance/keypoint split, then the "separated"
+    and "standard" variants of every predefined COCO panoptic split.
+
+    Args:
+        root (str): directory that the relative paths in the split tables are joined to.
+    """
+    for meta_key, named_splits in _PREDEFINED_SPLITS_COCO.items():
+        for split_name, (img_dir, ann_file) in named_splits.items():
+            # Assume pre-defined datasets live in `./datasets`.
+            # An annotation path containing "://" is used as given, without `root`.
+            ann_path = ann_file if "://" in ann_file else os.path.join(root, ann_file)
+            register_coco_instances(
+                split_name,
+                _get_builtin_metadata(meta_key),
+                ann_path,
+                os.path.join(root, img_dir),
+            )
+
+    for panoptic_name, split_paths in _PREDEFINED_SPLITS_COCO_PANOPTIC.items():
+        panoptic_root, panoptic_json, semantic_root = split_paths
+        # Dropping the "_panoptic" suffix gives the name of the matching instances
+        # split, whose image root and json file are reused here.
+        instances_name = panoptic_name[: -len("_panoptic")]
+        instances_meta = MetadataCatalog.get(instances_name)
+        image_root = instances_meta.image_root
+        instances_json = instances_meta.json_file
+        # The "separated" version of COCO panoptic segmentation dataset,
+        # e.g. used by Panoptic FPN
+        register_coco_panoptic_separated(
+            panoptic_name,
+            _get_builtin_metadata("coco_panoptic_separated"),
+            image_root,
+            os.path.join(root, panoptic_root),
+            os.path.join(root, panoptic_json),
+            os.path.join(root, semantic_root),
+            instances_json,
+        )
+        # The "standard" version of COCO panoptic segmentation dataset,
+        # e.g. used by Panoptic-DeepLab
+        register_coco_panoptic(
+            panoptic_name,
+            _get_builtin_metadata("coco_panoptic_standard"),
+            image_root,
+            os.path.join(root, panoptic_root),
+            os.path.join(root, panoptic_json),
+            instances_json,
+        )
+
+
+# ==== Predefined datasets and splits for LVIS ==========
+
+
+_PREDEFINED_SPLITS_LVIS = {
+ "lvis_v1": {
+ "lvis_v1_train": ("coco/", "lvis/lvis_v1_train.json"),
+ "lvis_v1_val": ("coco/", "lvis/lvis_v1_val.json"),
+ "lvis_v1_test_dev": ("coco/", "lvis/lvis_v1_image_info_test_dev.json"),
+ "lvis_v1_test_challenge": ("coco/", "lvis/lvis_v1_image_info_test_challenge.json"),
+ },
+ "lvis_v0.5": {
+ "lvis_v0.5_train": ("coco/", "lvis/lvis_v0.5_train.json"),
+ "lvis_v0.5_val": ("coco/", "lvis/lvis_v0.5_val.json"),
+ "lvis_v0.5_val_rand_100": ("coco/", "lvis/lvis_v0.5_val_rand_100.json"),
+ "lvis_v0.5_test": ("coco/", "lvis/lvis_v0.5_image_info_test.json"),
+ },
+ "lvis_v0.5_cocofied": {
+ "lvis_v0.5_train_cocofied": ("coco/", "lvis/lvis_v0.5_train_cocofied.json"),
+ "lvis_v0.5_val_cocofied": ("coco/", "lvis/lvis_v0.5_val_cocofied.json"),
+ },
+}
+
+
+def register_all_lvis(root):
+    """
+    Register every predefined LVIS split.
+
+    Args:
+        root (str): directory that the relative paths in the split table are joined to.
+    """
+    for version, named_splits in _PREDEFINED_SPLITS_LVIS.items():
+        for split_name, (img_dir, ann_file) in named_splits.items():
+            # An annotation path containing "://" is used as given, without `root`.
+            ann_path = ann_file if "://" in ann_file else os.path.join(root, ann_file)
+            register_lvis_instances(
+                split_name,
+                get_lvis_instances_meta(version),
+                ann_path,
+                os.path.join(root, img_dir),
+            )
+
+
+# ==== Predefined splits for raw cityscapes images ===========
+# Maps a dataset-name template to (image dir, ground-truth dir), both relative to
+# the datasets root. The "{task}" placeholder is filled in by
+# `register_all_cityscapes` with "instance_seg" and "sem_seg".
+_RAW_CITYSCAPES_SPLITS = {
+    "cityscapes_fine_{task}_train": ("cityscapes/leftImg8bit/train/", "cityscapes/gtFine/train/"),
+    "cityscapes_fine_{task}_val": ("cityscapes/leftImg8bit/val/", "cityscapes/gtFine/val/"),
+    "cityscapes_fine_{task}_test": ("cityscapes/leftImg8bit/test/", "cityscapes/gtFine/test/"),
+}
+
+
+def register_all_cityscapes(root):
+    """
+    Register an instance-segmentation and a semantic-segmentation dataset, with
+    their metadata, for every raw Cityscapes split.
+
+    Args:
+        root (str): directory that the relative paths in the split table are joined to.
+    """
+    for name_template, (img_subdir, gt_subdir) in _RAW_CITYSCAPES_SPLITS.items():
+        meta = _get_builtin_metadata("cityscapes")
+        image_dir = os.path.join(root, img_subdir)
+        gt_dir = os.path.join(root, gt_subdir)
+        dirs = {"image_dir": image_dir, "gt_dir": gt_dir}
+
+        # The directories are bound as lambda defaults so that each registered
+        # loader keeps the paths of its own loop iteration.
+        instance_name = name_template.format(task="instance_seg")
+        DatasetCatalog.register(
+            instance_name,
+            lambda img=image_dir, gt=gt_dir: load_cityscapes_instances(
+                img, gt, from_json=True, to_polygons=True
+            ),
+        )
+        MetadataCatalog.get(instance_name).set(
+            **dirs, evaluator_type="cityscapes_instance", **meta
+        )
+
+        semantic_name = name_template.format(task="sem_seg")
+        DatasetCatalog.register(
+            semantic_name, lambda img=image_dir, gt=gt_dir: load_cityscapes_semantic(img, gt)
+        )
+        MetadataCatalog.get(semantic_name).set(
+            **dirs, evaluator_type="cityscapes_sem_seg", ignore_label=255, **meta
+        )
+
+
+# ==== Predefined splits for PASCAL VOC ===========
+def register_all_pascal_voc(root):
+    """
+    Register the PASCAL VOC 2007 and 2012 splits and tag each one with the
+    "pascal_voc" evaluator type.
+
+    Args:
+        root (str): directory containing the "VOC2007" and "VOC2012" directories.
+    """
+    splits_per_year = (
+        (2007, ("trainval", "train", "val", "test")),
+        (2012, ("trainval", "train", "val")),
+    )
+    for year, split_names in splits_per_year:
+        for split in split_names:
+            name = f"voc_{year}_{split}"
+            register_pascal_voc(name, os.path.join(root, f"VOC{year}"), split, year)
+            MetadataCatalog.get(name).evaluator_type = "pascal_voc"
+
+
+def register_all_ade20k(root):
+    """
+    Register the ADE20K semantic segmentation train and val datasets with
+    their metadata.
+
+    Args:
+        root (str): directory containing "ADEChallengeData2016".
+    """
+    ade_root = os.path.join(root, "ADEChallengeData2016")
+    for split, subdir in (("train", "training"), ("val", "validation")):
+        image_dir = os.path.join(ade_root, "images", subdir)
+        gt_dir = os.path.join(ade_root, "annotations_detectron2", subdir)
+        dataset_name = f"ade20k_sem_seg_{split}"
+        # The directories are bound as lambda defaults so that each registered
+        # loader keeps the paths of its own loop iteration.
+        DatasetCatalog.register(
+            dataset_name,
+            lambda img=image_dir, gt=gt_dir: load_sem_seg(gt, img, gt_ext="png", image_ext="jpg"),
+        )
+        MetadataCatalog.get(dataset_name).set(
+            stuff_classes=ADE20K_SEM_SEG_CATEGORIES[:],
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=255,
+        )
+
+
+# True for open source;
+# Internally at fb, we register them elsewhere
+# The registrations below run at import time, and only when this module's dotted
+# name ends with ".builtin".
+if __name__.endswith(".builtin"):
+    # Assume pre-defined datasets live in `./datasets`.
+    # The DETECTRON2_DATASETS environment variable overrides that location;
+    # a leading "~" in it is expanded.
+    _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
+    register_all_coco(_root)
+    register_all_lvis(_root)
+    register_all_cityscapes(_root)
+    register_all_cityscapes_panoptic(_root)
+    register_all_pascal_voc(_root)
+    register_all_ade20k(_root)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/builtin_meta.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/builtin_meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c7a1a31b31dd89b82011effee26471faccacf5
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/builtin_meta.py
@@ -0,0 +1,350 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+Note:
+For your custom dataset, there is no need to hard-code metadata anywhere in the code.
+For example, for COCO-format dataset, metadata will be obtained automatically
+when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways
+during loading.
+
+However, we hard-coded metadata for a few common dataset here.
+The only goal is to allow users who don't have these dataset to use pre-trained models.
+Users don't have to download a COCO json (which contains metadata), in order to visualize a
+COCO model (with correct class names and colors).
+"""
+
+
+# All COCO categories ("thing" and "stuff"), together with their visualization colors.
+# Taken from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json
+COCO_CATEGORIES = [
+ {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+ {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+ {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+ {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+ {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+ {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+ {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+ {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+ {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+ {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+ {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+ {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+ {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+ {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+ {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+ {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+ {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+ {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+ {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+ {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+ {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+ {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+ {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+ {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+ {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+ {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+ {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+ {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+ {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+ {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+ {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+ {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+ {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+ {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+ {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+ {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+ {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+ {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+ {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+ {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+ {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+ {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+ {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+ {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+ {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+ {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+ {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+ {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+ {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+ {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+ {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+ {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+ {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+ {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+ {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+ {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+ {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+ {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+ {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+ {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+ {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+ {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+ {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+ {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+ {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+ {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+ {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+ {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+ {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+ {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+ {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+ {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+ {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+ {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+ {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+ {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+ {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+ {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+ {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+ {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+ {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"},
+ {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"},
+ {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"},
+ {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"},
+ {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"},
+ {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"},
+ {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"},
+ {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"},
+ {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"},
+ {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"},
+ {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"},
+ {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"},
+ {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"},
+ {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"},
+ {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"},
+ {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"},
+ {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"},
+ {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"},
+ {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"},
+ {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"},
+ {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"},
+ {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"},
+ {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"},
+ {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"},
+ {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"},
+ {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"},
+ {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"},
+ {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"},
+ {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"},
+ {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"},
+ {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"},
+ {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"},
+ {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"},
+ {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"},
+ {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"},
+ {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"},
+ {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"},
+ {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"},
+ {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"},
+ {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"},
+ {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"},
+ {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"},
+ {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"},
+ {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"},
+ {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"},
+ {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"},
+ {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"},
+ {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"},
+ {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"},
+ {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"},
+ {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"},
+ {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"},
+ {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"},
+]
+
+# fmt: off
+# Names of the 17 COCO person keypoints: the nose, followed by eight left/right pairs.
+COCO_PERSON_KEYPOINT_NAMES = (
+ "nose",
+ "left_eye", "right_eye",
+ "left_ear", "right_ear",
+ "left_shoulder", "right_shoulder",
+ "left_elbow", "right_elbow",
+ "left_wrist", "right_wrist",
+ "left_hip", "right_hip",
+ "left_knee", "right_knee",
+ "left_ankle", "right_ankle",
+)
+# fmt: on
+
+# Pairs of keypoints that should be exchanged under horizontal flipping.
+# "nose" has no left/right counterpart and therefore does not appear in any pair.
+COCO_PERSON_KEYPOINT_FLIP_MAP = (
+ ("left_eye", "right_eye"),
+ ("left_ear", "right_ear"),
+ ("left_shoulder", "right_shoulder"),
+ ("left_elbow", "right_elbow"),
+ ("left_wrist", "right_wrist"),
+ ("left_hip", "right_hip"),
+ ("left_knee", "right_knee"),
+ ("left_ankle", "right_ankle"),
+)
+
+# Rules for drawing the skeleton. Each rule is (keypoint name, keypoint name, color):
+# the two keypoints to draw a line between, and the line color as a 3-tuple of integers.
+KEYPOINT_CONNECTION_RULES = [
+ # face
+ ("left_ear", "left_eye", (102, 204, 255)),
+ ("right_ear", "right_eye", (51, 153, 255)),
+ ("left_eye", "nose", (102, 0, 204)),
+ ("nose", "right_eye", (51, 102, 255)),
+ # upper-body
+ ("left_shoulder", "right_shoulder", (255, 128, 0)),
+ ("left_shoulder", "left_elbow", (153, 255, 204)),
+ ("right_shoulder", "right_elbow", (128, 229, 255)),
+ ("left_elbow", "left_wrist", (153, 255, 153)),
+ ("right_elbow", "right_wrist", (102, 255, 224)),
+ # lower-body
+ ("left_hip", "right_hip", (255, 102, 0)),
+ ("left_hip", "left_knee", (255, 255, 77)),
+ ("right_hip", "right_knee", (153, 255, 204)),
+ ("left_knee", "left_ankle", (191, 255, 128)),
+ ("right_knee", "right_ankle", (255, 195, 77)),
+]
+
+# All Cityscapes categories, together with their visualization colors.
+# Taken from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa
+# In this list "trainId" runs contiguously from 0 to 18 in list order, while "id" has gaps.
+# The first 11 entries are "stuff" (isthing == 0); the remaining 8 are "things" (isthing == 1).
+CITYSCAPES_CATEGORIES = [
+ {"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"},
+ {"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"},
+ {"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"},
+ {"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"},
+ {"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"},
+ {"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"},
+ {"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"},
+ {"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"},
+ {"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"},
+ {"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"},
+ {"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"},
+ {"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"},
+ {"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"},
+ {"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"},
+ {"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"},
+ {"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"},
+ {"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"},
+ {"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"},
+ {"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"},
+]
+
+# fmt: off
+# Names of the 150 ADE20K semantic segmentation classes; the position in the list is the class index.
+ADE20K_SEM_SEG_CATEGORIES = [
+    "wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag",  # noqa
+]
+# After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore
+# fmt: on
+
+
+def _get_coco_instances_meta():
+    """
+    Build the COCO instance metadata from the "thing" entries of COCO_CATEGORIES.
+
+    Returns:
+        dict: with keys "thing_dataset_id_to_contiguous_id", "thing_classes"
+        and "thing_colors", all following the order of COCO_CATEGORIES.
+    """
+    things = [cat for cat in COCO_CATEGORIES if cat["isthing"] == 1]
+    assert len(things) == 80, len(things)
+    # Mapping from the non-contiguous COCO category id to an id in [0, 79]
+    dataset_to_contiguous = {cat["id"]: idx for idx, cat in enumerate(things)}
+    return {
+        "thing_dataset_id_to_contiguous_id": dataset_to_contiguous,
+        "thing_classes": [cat["name"] for cat in things],
+        "thing_colors": [cat["color"] for cat in things],
+    }
+
+
+def _get_coco_panoptic_separated_meta():
+ """
+ Returns metadata for "separated" version of the panoptic segmentation dataset.
+
+ Returns:
+ dict: the "stuff_*" entries built here ("stuff_dataset_id_to_contiguous_id",
+ "stuff_classes", "stuff_colors"), updated with every entry returned by
+ `_get_coco_instances_meta()`.
+ """
+ stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0]
+ assert len(stuff_ids) == 53, len(stuff_ids)
+
+ # For semantic segmentation, this mapping maps from the stuff ids in the dataset
+ # (used for processing results) to contiguous stuff ids in [0, 53] (used in models).
+ # The 53 real stuff categories get ids 1..53; id 0 is reserved for an extra
+ # category "things".
+ stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)}
+ # When converting COCO panoptic annotations to semantic annotations
+ # We label the "thing" category to 0
+ stuff_dataset_id_to_contiguous_id[0] = 0
+
+ # 54 names for COCO stuff categories (including "things");
+ # the "-other" and "-merged" suffixes are stripped from the names.
+ stuff_classes = ["things"] + [
+ k["name"].replace("-other", "").replace("-merged", "")
+ for k in COCO_CATEGORIES
+ if k["isthing"] == 0
+ ]
+
+ # NOTE: I randomly picked a color for things
+ stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0]
+ ret = {
+ "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+ "stuff_classes": stuff_classes,
+ "stuff_colors": stuff_colors,
+ }
+ # Add the "thing_*" entries of the instance metadata.
+ ret.update(_get_coco_instances_meta())
+ return ret
+
+
+def _get_builtin_metadata(dataset_name):
+    """
+    Return the hard-coded metadata of one of the built-in dataset families.
+
+    Args:
+        dataset_name (str): one of "coco", "coco_panoptic_separated",
+            "coco_panoptic_standard", "coco_person" or "cityscapes".
+
+    Returns:
+        dict: the metadata entries of the requested family.
+
+    Raises:
+        KeyError: if `dataset_name` is none of the names listed above.
+    """
+    if dataset_name == "coco":
+        return _get_coco_instances_meta()
+
+    if dataset_name == "coco_panoptic_separated":
+        return _get_coco_panoptic_separated_meta()
+
+    if dataset_name == "coco_panoptic_standard":
+        # Convert category id for training:
+        #   category id: like semantic segmentation, it is the class id for each
+        #   pixel. Since there are some classes not used in evaluation, the category
+        #   id is not always contiguous and thus we have two sets of category ids:
+        #     - original category id: category id in the original dataset, mainly
+        #       used for evaluation.
+        #     - contiguous category id: [0, #classes), in order to train the linear
+        #       softmax classifier.
+        # Here the contiguous id is the position of the category in COCO_CATEGORIES.
+        thing_id_map = {}
+        stuff_id_map = {}
+        for contiguous_id, category in enumerate(COCO_CATEGORIES):
+            id_map = thing_id_map if category["isthing"] else stuff_id_map
+            id_map[category["id"]] = contiguous_id
+
+        # The names and colors map contiguous ids in [0, #thing categories +
+        # #stuff categories) to their names and colors. The same names and colors
+        # are replicated under "thing_*" and "stuff_*" because the current
+        # visualization function in D2 handles thing and stuff classes differently
+        # due to some heuristic used in Panoptic FPN. We keep the same naming to
+        # enable reusing existing visualization functions.
+        return {
+            "thing_classes": [category["name"] for category in COCO_CATEGORIES],
+            "thing_colors": [category["color"] for category in COCO_CATEGORIES],
+            "stuff_classes": [category["name"] for category in COCO_CATEGORIES],
+            "stuff_colors": [category["color"] for category in COCO_CATEGORIES],
+            "thing_dataset_id_to_contiguous_id": thing_id_map,
+            "stuff_dataset_id_to_contiguous_id": stuff_id_map,
+        }
+
+    if dataset_name == "coco_person":
+        return {
+            "thing_classes": ["person"],
+            "keypoint_names": COCO_PERSON_KEYPOINT_NAMES,
+            "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP,
+            "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES,
+        }
+
+    if dataset_name == "cityscapes":
+        # fmt: off
+        cityscapes_things = [
+            "person", "rider", "car", "truck",
+            "bus", "train", "motorcycle", "bicycle",
+        ]
+        cityscapes_stuff = [
+            "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light",
+            "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car",
+            "truck", "bus", "train", "motorcycle", "bicycle",
+        ]
+        # fmt: on
+        return {
+            "thing_classes": cityscapes_things,
+            "stuff_classes": cityscapes_stuff,
+        }
+
+    raise KeyError("No built-in metadata for dataset {}".format(dataset_name))
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/cityscapes.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f646be9da15914c2ea5e34e478fda3cfb5fb309f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/cityscapes.py
@@ -0,0 +1,329 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import functools
+import json
+import logging
+import multiprocessing as mp
+import numpy as np
+import os
+from itertools import chain
+import annotator.oneformer.pycocotools.mask as mask_util
+from PIL import Image
+
+from annotator.oneformer.detectron2.structures import BoxMode
+from annotator.oneformer.detectron2.utils.comm import get_world_size
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+from annotator.oneformer.detectron2.utils.logger import setup_logger
+
+try:
+ import cv2 # noqa
+except ImportError:
+ # OpenCV is an optional dependency at the moment
+ pass
+
+
+logger = logging.getLogger(__name__)
+
+
+def _get_cityscapes_files(image_dir, gt_dir):
+    """
+    List every image under `image_dir` together with its annotation file paths.
+
+    `image_dir` is scanned as one sub-directory per city; the annotation paths
+    are derived from the image file name and the matching city folder in `gt_dir`.
+
+    Returns:
+        list[tuple]: one (image_file, instance_file, label_file, json_file)
+        tuple of paths per image.
+    """
+    image_suffix = "leftImg8bit.png"
+    found = []
+    # scan through the directory
+    city_names = PathManager.ls(image_dir)
+    logger.info(f"{len(city_names)} cities found in '{image_dir}'.")
+    for city in city_names:
+        img_city_dir = os.path.join(image_dir, city)
+        gt_city_dir = os.path.join(gt_dir, city)
+        for img_name in PathManager.ls(img_city_dir):
+            assert img_name.endswith(image_suffix), img_name
+            # Prefix shared by the image and all of its annotation files
+            stem = img_name[: -len(image_suffix)]
+            found.append(
+                (
+                    os.path.join(img_city_dir, img_name),
+                    os.path.join(gt_city_dir, stem + "gtFine_instanceIds.png"),
+                    os.path.join(gt_city_dir, stem + "gtFine_labelIds.png"),
+                    os.path.join(gt_city_dir, stem + "gtFine_polygons.json"),
+                )
+            )
+    assert len(found), "No images found in {}".format(image_dir)
+    # Sanity check: only the files of the first image are verified to exist
+    for path in found[0]:
+        assert PathManager.isfile(path), path
+    return found
+
+
+def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True):
+    """
+    Args:
+        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
+        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
+        from_json (bool): whether to read annotations from the raw json file or the png files.
+        to_polygons (bool): whether to represent the segmentation as polygons
+            (COCO's format) instead of masks (cityscapes's format).
+
+    Returns:
+        list[dict]: a list of dicts in Detectron2 standard format. (See
+        `Using Custom Datasets </tutorials/datasets.html>`_ )
+    """
+    if from_json:
+        assert to_polygons, (
+            "Cityscapes's json annotations are in polygon format. "
+            "Converting to mask format is not supported now."
+        )
+    files = _get_cityscapes_files(image_dir, gt_dir)
+
+    logger.info("Preprocessing cityscapes annotations ...")
+    # This is still not fast: all workers will execute duplicate works and will
+    # take up to 10m on a 8GPU server.
+    num_workers = max(mp.cpu_count() // get_world_size() // 2, 4)
+    # Use the pool as a context manager so that its worker processes are released
+    # once the (blocking) map call has returned, instead of leaking them.
+    with mp.Pool(processes=num_workers) as pool:
+        ret = pool.map(
+            functools.partial(
+                _cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons
+            ),
+            files,
+        )
+    logger.info("Loaded {} images from {}".format(len(ret), image_dir))
+
+    # Map cityscape ids to contiguous ids
+    from cityscapesscripts.helpers.labels import labels
+
+    labels = [l for l in labels if l.hasInstances and not l.ignoreInEval]
+    dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)}
+    # The per-image dicts carry raw cityscapes label ids; rewrite them in place.
+    for dict_per_image in ret:
+        for anno in dict_per_image["annotations"]:
+            anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]]
+    return ret
+
+
+def load_cityscapes_semantic(image_dir, gt_dir):
+    """
+    Build the semantic segmentation records of a Cityscapes split.
+
+    Args:
+        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
+        gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train".
+
+    Returns:
+        list[dict]: a list of dict, each has "file_name", "sem_seg_file_name",
+        "height" and "width".
+    """
+    # gt_dir is small and contains many small files, so it makes sense to fetch it locally first
+    local_gt_dir = PathManager.get_local_path(gt_dir)
+    records = []
+    for image_file, _, label_file, json_file in _get_cityscapes_files(image_dir, local_gt_dir):
+        # The image size is read from the polygon json of each image
+        with PathManager.open(json_file, "r") as f:
+            info = json.load(f)
+        record = {
+            "file_name": image_file,
+            "sem_seg_file_name": label_file.replace("labelIds", "labelTrainIds"),
+            "height": info["imgHeight"],
+            "width": info["imgWidth"],
+        }
+        records.append(record)
+    assert len(records), f"No images found in {image_dir}!"
+    first_label_file = records[0]["sem_seg_file_name"]
+    assert PathManager.isfile(
+        first_label_file
+    ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py"  # noqa
+    return records
+
+
+def _cityscapes_files_to_dict(files, from_json, to_polygons):
+ """
+ Parse cityscapes annotation files to a instance segmentation dataset dict.
+
+ Args:
+ files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file)
+ from_json (bool): whether to read annotations from the raw json file or the png files.
+ to_polygons (bool): whether to represent the segmentation as polygons
+ (COCO's format) instead of masks (cityscapes's format).
+ Only consulted when `from_json` is False.
+
+ Returns:
+ A dict in Detectron2 Dataset format, with keys "file_name", "image_id",
+ "height", "width" and "annotations". Each annotation's "category_id" is the
+ raw cityscapes label id (not a contiguous id).
+ """
+ from cityscapesscripts.helpers.labels import id2label, name2label
+
+ # The label id png (third entry) is not needed for instance annotations.
+ image_file, instance_id_file, _, json_file = files
+
+ annos = []
+
+ if from_json:
+ from shapely.geometry import MultiPolygon, Polygon
+
+ with PathManager.open(json_file, "r") as f:
+ jsonobj = json.load(f)
+ ret = {
+ "file_name": image_file,
+ "image_id": os.path.basename(image_file),
+ "height": jsonobj["imgHeight"],
+ "width": jsonobj["imgWidth"],
+ }
+
+ # `polygons_union` contains the union of all valid polygons.
+ polygons_union = Polygon()
+
+ # CityscapesScripts draw the polygons in sequential order
+ # and each polygon *overwrites* existing ones. See
+ # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py) # noqa
+ # We use reverse order, and each polygon *avoids* early ones.
+ # This will resolve the polygon overlaps in the same way as CityscapesScripts.
+ for obj in jsonobj["objects"][::-1]:
+ if "deleted" in obj: # cityscapes data format specific
+ continue
+ label_name = obj["label"]
+
+ try:
+ label = name2label[label_name]
+ except KeyError:
+ # Unknown names are only accepted when they are "<known label>group".
+ if label_name.endswith("group"): # crowd area
+ label = name2label[label_name[: -len("group")]]
+ else:
+ raise
+ if label.id < 0: # cityscapes data format
+ continue
+
+ # Cityscapes's raw annotations uses integer coordinates
+ # Therefore +0.5 here
+ poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5
+ # CityscapesScript uses PIL.ImageDraw.polygon to rasterize
+ # polygons for evaluation. This function operates in integer space
+ # and draws each pixel whose center falls into the polygon.
+ # Therefore it draws a polygon which is 0.5 "fatter" in expectation.
+ # We therefore dilate the input polygon by 0.5 as our input.
+ poly = Polygon(poly_coord).buffer(0.5, resolution=4)
+
+ if not label.hasInstances or label.ignoreInEval:
+ # even if we won't store the polygon it still contributes to overlaps resolution
+ polygons_union = polygons_union.union(poly)
+ continue
+
+ # Take non-overlapping part of the polygon
+ poly_wo_overlaps = poly.difference(polygons_union)
+ if poly_wo_overlaps.is_empty:
+ continue
+ polygons_union = polygons_union.union(poly)
+
+ anno = {}
+ # A label name ending in "group" marks a crowd region.
+ anno["iscrowd"] = label_name.endswith("group")
+ anno["category_id"] = label.id
+
+ # The difference may have split the polygon into several pieces.
+ if isinstance(poly_wo_overlaps, Polygon):
+ poly_list = [poly_wo_overlaps]
+ elif isinstance(poly_wo_overlaps, MultiPolygon):
+ poly_list = poly_wo_overlaps.geoms
+ else:
+ raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps))
+
+ poly_coord = []
+ for poly_el in poly_list:
+ # COCO API can work only with exterior boundaries now, hence we store only them.
+ # TODO: store both exterior and interior boundaries once other parts of the
+ # codebase support holes in polygons.
+ # Each exterior ring is flattened into one flat coordinate list.
+ poly_coord.append(list(chain(*poly_el.exterior.coords)))
+ anno["segmentation"] = poly_coord
+ (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds
+
+ anno["bbox"] = (xmin, ymin, xmax, ymax)
+ anno["bbox_mode"] = BoxMode.XYXY_ABS
+
+ annos.append(anno)
+ else:
+ # See also the official annotation parsing scripts at
+ # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py # noqa
+ with PathManager.open(instance_id_file, "rb") as f:
+ inst_image = np.asarray(Image.open(f), order="F")
+ # ids < 24 are stuff labels (filtering them first is about 5% faster)
+ flattened_ids = np.unique(inst_image[inst_image >= 24])
+
+ ret = {
+ "file_name": image_file,
+ "image_id": os.path.basename(image_file),
+ "height": inst_image.shape[0],
+ "width": inst_image.shape[1],
+ }
+
+ for instance_id in flattened_ids:
+ # For non-crowd annotations, instance_id // 1000 is the label_id
+ # Crowd annotations have <1000 instance ids
+ label_id = instance_id // 1000 if instance_id >= 1000 else instance_id
+ label = id2label[label_id]
+ if not label.hasInstances or label.ignoreInEval:
+ continue
+
+ anno = {}
+ anno["iscrowd"] = instance_id < 1000
+ anno["category_id"] = label.id
+
+ mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F")
+
+ # The box spans the smallest and largest row/column indices of the mask pixels.
+ inds = np.nonzero(mask)
+ ymin, ymax = inds[0].min(), inds[0].max()
+ xmin, xmax = inds[1].min(), inds[1].max()
+ anno["bbox"] = (xmin, ymin, xmax, ymax)
+ # Skip instances whose mask covers only a single row or a single column.
+ if xmax <= xmin or ymax <= ymin:
+ continue
+ anno["bbox_mode"] = BoxMode.XYXY_ABS
+ if to_polygons:
+ # This conversion comes from D4809743 and D5171122,
+ # when Mask-RCNN was first developed.
+ # NOTE(review): needs cv2, which this module imports only optionally.
+ # Indexing with [-2] picks the contours from the returned tuple.
+ contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[
+ -2
+ ]
+ polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3]
+ # opencv can produce invalid polygons (contours with fewer than 3 points)
+ if len(polygons) == 0:
+ continue
+ anno["segmentation"] = polygons
+ else:
+ anno["segmentation"] = mask_util.encode(mask[:, :, None])[0]
+ annos.append(anno)
+ ret["annotations"] = annos
+ return ret
+
+
+# Command-line entry point: load a cityscapes split and save one visualization per image.
+if __name__ == "__main__":
+ """
+ Test the cityscapes dataset loader.
+
+ Usage:
+ python -m detectron2.data.datasets.cityscapes \
+ cityscapes/leftImg8bit/train cityscapes/gtFine/train
+ """
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("image_dir")
+ parser.add_argument("gt_dir")
+ parser.add_argument("--type", choices=["instance", "semantic"], default="instance")
+ args = parser.parse_args()
+ from annotator.oneformer.detectron2.data.catalog import Metadata
+ from annotator.oneformer.detectron2.utils.visualizer import Visualizer
+ from cityscapesscripts.helpers.labels import labels
+
+ logger = setup_logger(name=__name__)
+
+ # Output directory for the rendered images, relative to the working directory.
+ dirname = "cityscapes-data-vis"
+ os.makedirs(dirname, exist_ok=True)
+
+ if args.type == "instance":
+ dicts = load_cityscapes_instances(
+ args.image_dir, args.gt_dir, from_json=True, to_polygons=True
+ )
+ logger.info("Done loading {} samples.".format(len(dicts)))
+
+ # Same label filter as the contiguous id mapping in load_cityscapes_instances.
+ thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval]
+ meta = Metadata().set(thing_classes=thing_classes)
+
+ else:
+ dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir)
+ logger.info("Done loading {} samples.".format(len(dicts)))
+
+ # Labels with trainId 255 are left out of the class names and colors.
+ stuff_classes = [k.name for k in labels if k.trainId != 255]
+ stuff_colors = [k.color for k in labels if k.trainId != 255]
+ meta = Metadata().set(stuff_classes=stuff_classes, stuff_colors=stuff_colors)
+
+ for d in dicts:
+ img = np.array(Image.open(PathManager.open(d["file_name"], "rb")))
+ visualizer = Visualizer(img, metadata=meta)
+ vis = visualizer.draw_dataset_dict(d)
+ # cv2.imshow("a", vis.get_image()[:, :, ::-1])
+ # cv2.waitKey()
+ # Saved under the image's base name inside `dirname`.
+ fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
+ vis.save(fpath)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/cityscapes_panoptic.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/cityscapes_panoptic.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce9ec48f673dadf3f5b4ae0592fc82415d9f925
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/cityscapes_panoptic.py
@@ -0,0 +1,187 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import json
+import logging
+import os
+
+from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
+from annotator.oneformer.detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+"""
+This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
+"""
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
+    """Pair every panoptic annotation in ``json_info`` with its image and label paths.
+
+    Returns a list of ``(image_file, label_file, segments_info)`` tuples, one per
+    entry of ``json_info["annotations"]``, in annotation order.
+    """
+    suffix = "_leftImg8bit.png"
+    city_names = PathManager.ls(image_dir)
+    logger.info(f"{len(city_names)} cities found in '{image_dir}'.")
+    # Index image paths by file name minus the "_leftImg8bit.png" suffix.
+    path_by_id = {}
+    for city_name in city_names:
+        city_dir = os.path.join(image_dir, city_name)
+        for entry in PathManager.ls(city_dir):
+            assert entry.endswith(suffix), entry
+            path_by_id[os.path.basename(entry)[: -len(suffix)]] = os.path.join(city_dir, entry)
+
+    triples = []
+    for record in json_info["annotations"]:
+        image_id, label_name = record["image_id"], record["file_name"]
+        image_path = path_by_id.get(image_id, None)
+        assert image_path is not None, "No image {} found for annotation {}".format(
+            image_id, label_name
+        )
+        triples.append((image_path, os.path.join(gt_dir, label_name), record["segments_info"]))
+
+    assert len(triples), "No images found in {}".format(image_dir)
+    first_image, first_label = triples[0][0], triples[0][1]
+    assert PathManager.isfile(first_image), first_image
+    assert PathManager.isfile(first_label), first_label
+    return triples
+
+
+def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
+    """Build Detectron2 dataset dicts for one Cityscapes panoptic split.
+
+    Args:
+        image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
+        gt_dir (str): path to the raw annotations. e.g.,
+            "~/cityscapes/gtFine/cityscapes_panoptic_train".
+        gt_json (str): path to the json file. e.g.,
+            "~/cityscapes/gtFine/cityscapes_panoptic_train.json".
+        meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
+            and "stuff_dataset_id_to_contiguous_id" to map category ids to
+            contiguous ids for training.
+
+    Returns:
+        list[dict]: dicts in Detectron2 standard format (see the "Use Custom Datasets" tutorial).
+    """
+    # Rewrites "category_id" in place, from dataset id to contiguous training id.
+    def _convert_category_id(segment_info, meta):
+        if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
+            segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
+                segment_info["category_id"]
+            ]
+        else:
+            segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
+                segment_info["category_id"]
+            ]
+        return segment_info
+
+    assert os.path.exists(
+        gt_json
+    ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files."  # noqa
+    with open(gt_json) as f:
+        json_info = json.load(f)
+    files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
+    ret = []
+    for image_file, label_file, segments_info in files:
+        # splitext, not split("."): a dot elsewhere in the path must not truncate it.
+        gt_stem = os.path.splitext(image_file.replace("leftImg8bit", "gtFine"))[0]
+        sem_label_file = gt_stem + "_labelTrainIds.png"
+        segments_info = [_convert_category_id(x, meta) for x in segments_info]
+        ret.append(
+            {
+                "file_name": image_file,
+                "image_id": "_".join(
+                    os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
+                ),
+                "sem_seg_file_name": sem_label_file,
+                "pan_seg_file_name": label_file,
+                "segments_info": segments_info,
+            }
+        )
+    assert len(ret), f"No images found in {image_dir}!"
+    assert PathManager.isfile(
+        ret[0]["sem_seg_file_name"]
+    ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py"  # noqa
+    assert PathManager.isfile(
+        ret[0]["pan_seg_file_name"]
+    ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py"  # noqa
+    return ret
+
+
+_RAW_CITYSCAPES_PANOPTIC_SPLITS = {
+ "cityscapes_fine_panoptic_train": (
+ "cityscapes/leftImg8bit/train", # image_dir
+ "cityscapes/gtFine/cityscapes_panoptic_train", # gt_dir
+ "cityscapes/gtFine/cityscapes_panoptic_train.json", # gt_json
+ ),
+ "cityscapes_fine_panoptic_val": (
+ "cityscapes/leftImg8bit/val",
+ "cityscapes/gtFine/cityscapes_panoptic_val",
+ "cityscapes/gtFine/cityscapes_panoptic_val.json",
+ ),
+ # Each value: (image_dir, gt_dir, gt_json). "cityscapes_fine_panoptic_test": not supported yet
+}
+
+
+def register_all_cityscapes_panoptic(root):
+    """Register every Cityscapes panoptic split with the dataset and metadata catalogs.
+
+    Args:
+        root (str): directory that the relative paths in
+            ``_RAW_CITYSCAPES_PANOPTIC_SPLITS`` are joined onto.
+    """
+    # The following metadata maps contiguous id from [0, #thing categories +
+    # #stuff categories) to their names and colors. The same names and colors
+    # are stored twice, under "thing_*" and "stuff_*", because the current
+    # visualization function in D2 handles thing and stuff classes differently
+    # due to some heuristic used in Panoptic FPN. Keeping this naming lets the
+    # existing visualization functions be reused.
+    meta = {
+        "thing_classes": [cat["name"] for cat in CITYSCAPES_CATEGORIES],
+        "thing_colors": [cat["color"] for cat in CITYSCAPES_CATEGORIES],
+        "stuff_classes": [cat["name"] for cat in CITYSCAPES_CATEGORIES],
+        "stuff_colors": [cat["color"] for cat in CITYSCAPES_CATEGORIES],
+    }
+
+    # Cityscapes panoptic segmentation uses three kinds of ids:
+    # (1) category id: as in semantic segmentation, the class id of each
+    #     pixel. Some classes are not used in evaluation, so category ids are
+    #     not always contiguous and there are two sets of them:
+    #       - original category id: the id in the original dataset, mainly
+    #         used for evaluation.
+    #       - contiguous category id: [0, #classes), used to train the
+    #         classifier.
+    # (2) instance id: differentiates instances of the same category. It is
+    #     always 0 for "stuff" classes; for "thing" classes it starts from 1,
+    #     and 0 is reserved for ignored instances (e.g. crowd annotation).
+    # (3) panoptic id: the compact id that encodes both category and instance
+    #     id as category_id * 1000 + instance_id.
+    thing_id_map = {}
+    stuff_id_map = {}
+    for cat in CITYSCAPES_CATEGORIES:
+        target_map = thing_id_map if cat["isthing"] == 1 else stuff_id_map
+        target_map[cat["id"]] = cat["trainId"]
+    meta["thing_dataset_id_to_contiguous_id"] = thing_id_map
+    meta["stuff_dataset_id_to_contiguous_id"] = stuff_id_map
+
+    for name, (image_rel, gt_rel, json_rel) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
+        image_dir = os.path.join(root, image_rel)
+        gt_dir = os.path.join(root, gt_rel)
+        gt_json = os.path.join(root, json_rel)
+
+        # Default arguments pin this iteration's paths inside the loader.
+        DatasetCatalog.register(
+            name,
+            lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta),
+        )
+        split_info = dict(
+            panoptic_root=gt_dir,
+            image_root=image_dir,
+            panoptic_json=gt_json,
+            gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
+            evaluator_type="cityscapes_panoptic_seg",
+            ignore_label=255,
+            label_divisor=1000,
+        )
+        MetadataCatalog.get(name).set(**split_info, **meta)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/coco.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a7cdba855979f9453904b1d6f0aedd47dd81200
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/coco.py
@@ -0,0 +1,539 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import datetime
+import io
+import json
+import logging
+import numpy as np
+import os
+import shutil
+import annotator.oneformer.pycocotools.mask as mask_util
+from fvcore.common.timer import Timer
+from iopath.common.file_io import file_lock
+from PIL import Image
+
+from annotator.oneformer.detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .. import DatasetCatalog, MetadataCatalog
+
+"""
+This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format".
+"""
+
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances"]
+
+
+def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
+ """
+ Load a json file with COCO's instances annotation format.
+ Currently supports instance detection, instance segmentation,
+ and person keypoints annotations.
+
+ Args:
+ json_file (str): full path to the json file in COCO instances annotation format.
+ image_root (str or path-like): the directory where the images in this json file exists.
+ dataset_name (str or None): the name of the dataset (e.g., coco_2017_train).
+ When provided, this function will also do the following:
+
+ * Put "thing_classes" into the metadata associated with this dataset.
+ * Map the category ids into a contiguous range (needed by standard dataset format),
+ and add "thing_dataset_id_to_contiguous_id" to the metadata associated
+ with this dataset.
+
+ This option should usually be provided, unless users need to load
+ the original json content and apply more processing manually.
+ extra_annotation_keys (list[str]): list of per-annotation keys that should also be
+ loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
+ "category_id", "segmentation"). The values for these keys will be returned as-is.
+ For example, the densepose annotations are loaded in this way.
+
+ Returns:
+ list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See
+ https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html ) when `dataset_name` is not None.
+ If `dataset_name` is None, the returned `category_ids` may be
+ non-contiguous and may not conform to the Detectron2 standard format.
+
+ Notes:
+ 1. This function does not read the image files.
+ The results do not have the "image" field.
+ """
+ from annotator.oneformer.pycocotools.coco import COCO
+
+ timer = Timer()
+ json_file = PathManager.get_local_path(json_file)
+ with contextlib.redirect_stdout(io.StringIO()):
+ coco_api = COCO(json_file)
+ if timer.seconds() > 1:
+ logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+
+ id_map = None # stays None unless dataset_name is given and ids do not span [1, #categories]
+ if dataset_name is not None:
+ meta = MetadataCatalog.get(dataset_name)
+ cat_ids = sorted(coco_api.getCatIds())
+ cats = coco_api.loadCats(cat_ids)
+ # The categories in a custom json file may not be sorted.
+ thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
+ meta.thing_classes = thing_classes
+
+ # In COCO, certain category ids are artificially removed,
+ # and by convention they are always ignored.
+ # We deal with COCO's id issue and translate
+ # the category ids to contiguous ids in [0, 80).
+
+ # It works by looking at the "categories" field in the json, therefore
+ # if a user's own json also has non-contiguous ids, we'll
+ # apply this mapping as well but print a warning.
+ if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
+ if "coco" not in dataset_name:
+ logger.warning(
+ """
+Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
+"""
+ )
+ id_map = {v: i for i, v in enumerate(cat_ids)}
+ meta.thing_dataset_id_to_contiguous_id = id_map
+
+ # sort indices for reproducible results
+ img_ids = sorted(coco_api.imgs.keys())
+ # imgs is a list of dicts, each looks something like:
+ # {'license': 4,
+ # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+ # 'file_name': 'COCO_val2014_000000001268.jpg',
+ # 'height': 427,
+ # 'width': 640,
+ # 'date_captured': '2013-11-17 05:57:24',
+ # 'id': 1268}
+ imgs = coco_api.loadImgs(img_ids)
+ # anns is a list[list[dict]], where each dict is an annotation
+ # record for an object. The inner list enumerates the objects in an image
+ # and the outer list enumerates over images. Example of anns[0]:
+ # [{'segmentation': [[192.81,
+ # 247.09,
+ # ...
+ # 219.03,
+ # 249.06]],
+ # 'area': 1035.749,
+ # 'iscrowd': 0,
+ # 'image_id': 1268,
+ # 'bbox': [192.81, 224.8, 74.73, 33.43],
+ # 'category_id': 16,
+ # 'id': 42986},
+ # ...]
+ anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
+ total_num_valid_anns = sum([len(x) for x in anns])
+ total_num_anns = len(coco_api.anns)
+ if total_num_valid_anns < total_num_anns:
+ logger.warning(
+ f"{json_file} contains {total_num_anns} annotations, but only "
+ f"{total_num_valid_anns} of them match to images in the file."
+ )
+
+ if "minival" not in json_file:
+ # The popular valminusminival & minival annotations for COCO2014 contain this bug.
+ # However the ratio of buggy annotations there is tiny and does not affect accuracy.
+ # Therefore we explicitly white-list them.
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+ assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
+ json_file
+ )
+
+ imgs_anns = list(zip(imgs, anns))
+ logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
+
+ dataset_dicts = []
+
+ ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])
+
+ num_instances_without_valid_segmentation = 0
+
+ for (img_dict, anno_dict_list) in imgs_anns:
+ record = {}
+ record["file_name"] = os.path.join(image_root, img_dict["file_name"])
+ record["height"] = img_dict["height"]
+ record["width"] = img_dict["width"]
+ image_id = record["image_id"] = img_dict["id"]
+
+ objs = []
+ for anno in anno_dict_list:
+ # Check that the image_id in this annotation is the same as
+ # the image_id we're looking at.
+ # This fails only when the data parsing logic or the annotation file is buggy.
+
+ # The original COCO valminusminival2014 & minival2014 annotation files
+ # actually contains bugs that, together with certain ways of using COCO API,
+ # can trigger this assertion.
+ assert anno["image_id"] == image_id
+
+ assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
+
+ obj = {key: anno[key] for key in ann_keys if key in anno}
+ if "bbox" in obj and len(obj["bbox"]) == 0:
+ raise ValueError(
+ f"One annotation of image {image_id} contains empty 'bbox' value! "
+ "This json does not have valid COCO format."
+ )
+
+ segm = anno.get("segmentation", None)
+ if segm: # either list[list[float]] or dict(RLE)
+ if isinstance(segm, dict):
+ if isinstance(segm["counts"], list):
+ # convert to compressed RLE
+ segm = mask_util.frPyObjects(segm, *segm["size"])
+ else:
+ # filter out invalid polygons (< 3 points)
+ segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+ if len(segm) == 0:
+ num_instances_without_valid_segmentation += 1
+ continue # ignore this instance
+ obj["segmentation"] = segm
+
+ keypts = anno.get("keypoints", None)
+ if keypts: # list[int]
+ for idx, v in enumerate(keypts):
+ if idx % 3 != 2:
+ # COCO's segmentation coordinates are floating points in [0, H or W],
+ # but keypoint coordinates are integers in [0, H-1 or W-1]
+ # Therefore we assume the coordinates are "pixel indices" and
+ # add 0.5 to convert to floating point coordinates.
+ keypts[idx] = v + 0.5 # in place: edits the list stored in `anno`
+ obj["keypoints"] = keypts
+
+ obj["bbox_mode"] = BoxMode.XYWH_ABS
+ if id_map:
+ annotation_category_id = obj["category_id"]
+ try:
+ obj["category_id"] = id_map[annotation_category_id]
+ except KeyError as e:
+ raise KeyError(
+ f"Encountered category_id={annotation_category_id} "
+ "but this id does not exist in 'categories' of the json file."
+ ) from e
+ objs.append(obj)
+ record["annotations"] = objs
+ dataset_dicts.append(record)
+
+ if num_instances_without_valid_segmentation > 0:
+ logger.warning(
+ "Filtered out {} instances without valid segmentation. ".format(
+ num_instances_without_valid_segmentation
+ )
+ + "There might be issues in your dataset generation process. Please "
+ "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
+ )
+ return dataset_dicts
+
+
+def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"):
+    """
+    Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
+    treated as ground truth annotations and all files under "image_root" with "image_ext" extension
+    as input images. Ground truth and input images are matched using file paths relative to
+    "gt_root" and "image_root" respectively without taking into account file extensions.
+    This works for COCO as well as some other datasets.
+
+    Args:
+        gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
+            annotations are stored as images with integer values in pixels that represent
+            corresponding semantic labels.
+        image_root (str): the directory where the input images are.
+        gt_ext (str): file extension for ground truth annotations.
+        image_ext (str): file extension for input images.
+
+    Returns:
+        list[dict]: a list of dicts in detectron2 standard format without instance-level
+            annotation.
+
+    Notes:
+        1. This function does not read the image and ground truth files.
+           The results do not have the "image" and "sem_seg" fields.
+    """
+
+    # We match input images with ground truth based on their relative filepaths (without file
+    # extensions) starting from 'image_root' and 'gt_root' respectively.
+    def file2id(folder_path, file_path):
+        # extract relative path starting from `folder_path`
+        image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
+        # remove file extension
+        image_id = os.path.splitext(image_id)[0]
+        return image_id
+
+    input_files = sorted(
+        (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
+        key=lambda file_path: file2id(image_root, file_path),
+    )
+    gt_files = sorted(
+        (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
+        key=lambda file_path: file2id(gt_root, file_path),
+    )
+
+    assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)
+
+    # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
+    if len(input_files) != len(gt_files):
+        # logger.warning, not the deprecated logger.warn alias (removed in Python 3.13).
+        logger.warning(
+            "Directory {} and {} has {} and {} files, respectively.".format(
+                image_root, gt_root, len(input_files), len(gt_files)
+            )
+        )
+        input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files]
+        gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files]
+        intersect = list(set(input_basenames) & set(gt_basenames))
+        # sort, otherwise each worker may obtain a list[dict] in different order
+        intersect = sorted(intersect)
+        logger.warning("Will use their intersection of {} files.".format(len(intersect)))
+        input_files = [os.path.join(image_root, f + image_ext) for f in intersect]
+        gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect]
+
+    logger.info(
+        "Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root)
+    )
+
+    dataset_dicts = []
+    for (img_path, gt_path) in zip(input_files, gt_files):
+        record = {}
+        record["file_name"] = img_path
+        record["sem_seg_file_name"] = gt_path
+        dataset_dicts.append(record)
+
+    return dataset_dicts
+
+
+def convert_to_coco_dict(dataset_name):
+    """
+    Convert an instance detection/segmentation or keypoint detection dataset
+    in detectron2's standard format into COCO json format.
+
+    Generic dataset description can be found here:
+    https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset
+
+    COCO data format description can be found here:
+    http://cocodataset.org/#format-data
+
+    Args:
+        dataset_name (str):
+            name of the source dataset
+            Must be registered in DatasetCatalog and in detectron2's standard format.
+            Must have corresponding metadata "thing_classes"
+    Returns:
+        coco_dict: serializable dict in COCO json format
+    """
+
+    dataset_dicts = DatasetCatalog.get(dataset_name)
+    metadata = MetadataCatalog.get(dataset_name)
+
+    # unmap the category mapping ids for COCO
+    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
+        reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
+        reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id]  # noqa
+    else:
+        reverse_id_mapper = lambda contiguous_id: contiguous_id  # noqa
+
+    categories = [
+        {"id": reverse_id_mapper(id), "name": name}
+        for id, name in enumerate(metadata.thing_classes)
+    ]
+
+    logger.info("Converting dataset dicts into COCO format")
+    coco_images = []
+    coco_annotations = []
+
+    for image_id, image_dict in enumerate(dataset_dicts):
+        coco_image = {
+            "id": image_dict.get("image_id", image_id),
+            "width": int(image_dict["width"]),
+            "height": int(image_dict["height"]),
+            "file_name": str(image_dict["file_name"]),
+        }
+        coco_images.append(coco_image)
+
+        anns_per_image = image_dict.get("annotations", [])
+        for annotation in anns_per_image:
+            # create a new dict with only COCO fields
+            coco_annotation = {}
+
+            # COCO requirement: XYWH box format for axis-align and XYWHA for rotated
+            bbox = annotation["bbox"]
+            if isinstance(bbox, np.ndarray):
+                if bbox.ndim != 1:
+                    raise ValueError(f"bbox has to be 1-dimensional. Got shape={bbox.shape}.")
+                bbox = bbox.tolist()
+            if len(bbox) not in [4, 5]:
+                raise ValueError(f"bbox has to has length 4 or 5. Got {bbox}.")
+            from_bbox_mode = annotation["bbox_mode"]
+            to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS
+            bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode)
+
+            # COCO requirement: instance area
+            if "segmentation" in annotation:
+                # Computing areas for instances by counting the pixels
+                segmentation = annotation["segmentation"]
+                # TODO: check segmentation type: RLE, BinaryMask or Polygon
+                if isinstance(segmentation, list):
+                    polygons = PolygonMasks([segmentation])
+                    area = polygons.area()[0].item()
+                elif isinstance(segmentation, dict):  # RLE
+                    area = mask_util.area(segmentation).item()
+                else:
+                    raise TypeError(f"Unknown segmentation type {type(segmentation)}!")
+            else:
+                # Computing areas using bounding boxes
+                if to_bbox_mode == BoxMode.XYWH_ABS:
+                    bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS)
+                    area = Boxes([bbox_xy]).area()[0].item()
+                else:
+                    area = RotatedBoxes([bbox]).area()[0].item()
+
+            if "keypoints" in annotation:
+                keypoints = list(annotation["keypoints"])  # list[int]; copy, the input is not shifted
+                for idx, v in enumerate(keypoints):
+                    if idx % 3 != 2:
+                        # COCO's segmentation coordinates are floating points in [0, H or W],
+                        # but keypoint coordinates are integers in [0, H-1 or W-1]
+                        # For COCO format consistency we subtract 0.5
+                        # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163
+                        keypoints[idx] = v - 0.5
+                if "num_keypoints" in annotation:
+                    num_keypoints = annotation["num_keypoints"]
+                else:
+                    num_keypoints = sum(kp > 0 for kp in keypoints[2::3])
+
+            # COCO requirement:
+            #   linking annotations to images
+            #   "id" field must start with 1
+            coco_annotation["id"] = len(coco_annotations) + 1
+            coco_annotation["image_id"] = coco_image["id"]
+            coco_annotation["bbox"] = [round(float(x), 3) for x in bbox]
+            coco_annotation["area"] = float(area)
+            coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0))
+            coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"]))
+
+            # Add optional fields
+            if "keypoints" in annotation:
+                coco_annotation["keypoints"] = keypoints
+                coco_annotation["num_keypoints"] = num_keypoints
+
+            if "segmentation" in annotation:
+                seg = coco_annotation["segmentation"] = annotation["segmentation"]
+                if isinstance(seg, dict):  # RLE
+                    counts = seg["counts"]
+                    if not isinstance(counts, str):
+                        # make it json-serializable; use a copy so the input RLE keeps its bytes
+                        coco_annotation["segmentation"] = {**seg, "counts": counts.decode("ascii")}
+
+            coco_annotations.append(coco_annotation)
+
+    logger.info(
+        "Conversion finished, "
+        f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
+    )
+
+    info = {
+        "date_created": str(datetime.datetime.now()),
+        "description": "Automatically generated COCO json file for Detectron2.",
+    }
+    coco_dict = {"info": info, "images": coco_images, "categories": categories, "licenses": None}
+    if len(coco_annotations) > 0:
+        coco_dict["annotations"] = coco_annotations
+    return coco_dict
+
+
+def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
+    """
+    Converts dataset into COCO format and saves it to a json file.
+    dataset_name must be registered in DatasetCatalog and in detectron2's standard format.
+
+    Args:
+        dataset_name: reference from the config file to the catalogs;
+            must be registered in DatasetCatalog and in detectron2's standard format
+        output_file: path of json file that will be saved to
+        allow_cached: if json file is already present then skip conversion
+    """
+    # TODO: The dataset or the conversion script *may* change,
+    # a checksum would be useful for validating the cached data
+
+    output_dir = os.path.dirname(output_file)
+    if output_dir:  # a bare file name has an empty dirname: there is no directory to create
+        PathManager.mkdirs(output_dir)
+    with file_lock(output_file):
+        if PathManager.exists(output_file) and allow_cached:
+            logger.warning(
+                f"Using previously cached COCO format annotations at '{output_file}'. "
+                "You need to clear the cache file if your dataset has been modified."
+            )
+        else:
+            logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...")
+            coco_dict = convert_to_coco_dict(dataset_name)
+
+            logger.info(f"Caching COCO format annotations at '{output_file}' ...")
+            tmp_file = output_file + ".tmp"  # written first, then moved onto output_file
+            with PathManager.open(tmp_file, "w") as f:
+                json.dump(coco_dict, f)
+            shutil.move(tmp_file, output_file)
+
+
+def register_coco_instances(name, metadata, json_file, image_root):
+    """
+    Register, under ``name``, a dataset stored in COCO's json annotation format for
+    instance detection, instance segmentation and keypoint detection
+    (i.e., Type 1 and 2 in http://cocodataset.org/#format-data:
+    `instances*.json` and `person_keypoints*.json` in the dataset).
+
+    This function also serves as a template for registering other datasets.
+
+    Args:
+        name (str): the name that identifies a dataset, e.g. "coco_2014_train".
+        metadata (dict): extra metadata associated with this dataset; an
+            empty dict is acceptable.
+        json_file (str or path-like): path to the json instance annotation file.
+        image_root (str or path-like): directory which contains all the images.
+    """
+    path_types = (str, os.PathLike)
+    for value, expected in ((name, str), (json_file, path_types), (image_root, path_types)):
+        assert isinstance(value, expected), value
+
+    # Step 1: register a function which returns the dataset dicts.
+    def _load_dicts():
+        return load_coco_json(json_file, image_root, name)
+
+    DatasetCatalog.register(name, _load_dicts)
+    # Step 2: attach metadata, which may help evaluation, visualization or logging.
+    fixed = dict(json_file=json_file, image_root=image_root, evaluator_type="coco")
+    MetadataCatalog.get(name).set(**fixed, **metadata)
+
+
+if __name__ == "__main__":
+    """
+    Test the COCO json dataset loader.
+
+    Usage:
+        python -m detectron2.data.datasets.coco \
+            path/to/json path/to/image_root dataset_name
+
+        "dataset_name" can be "coco_2014_minival_100", or other
+        pre-registered ones
+    """
+    from annotator.oneformer.detectron2.utils.logger import setup_logger
+    from annotator.oneformer.detectron2.utils.visualizer import Visualizer
+    import annotator.oneformer.detectron2.data.datasets  # noqa # add pre-defined metadata
+    import sys
+
+    logger = setup_logger(name=__name__)
+    json_path, image_root, dataset_name = sys.argv[1], sys.argv[2], sys.argv[3]
+    assert dataset_name in DatasetCatalog.list()
+    meta = MetadataCatalog.get(dataset_name)
+
+    dicts = load_coco_json(json_path, image_root, dataset_name)
+    logger.info("Done loading {} samples.".format(len(dicts)))
+
+    out_dir = "coco-data-vis"
+    os.makedirs(out_dir, exist_ok=True)
+    # Draw each record's annotations over its image and save the result in out_dir.
+    for record in dicts:
+        image = np.array(Image.open(record["file_name"]))
+        rendered = Visualizer(image, metadata=meta).draw_dataset_dict(record)
+        rendered.save(os.path.join(out_dir, os.path.basename(record["file_name"])))
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/coco_panoptic.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/coco_panoptic.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7180df512c29665222b1a90323ccfa7e7623137
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/coco_panoptic.py
@@ -0,0 +1,228 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import json
+import os
+
+from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .coco import load_coco_json, load_sem_seg
+
+__all__ = ["register_coco_panoptic", "register_coco_panoptic_separated"]
+
+
+def load_coco_panoptic_json(json_file, image_dir, gt_dir, meta):
+    """
+    Load a COCO panoptic json annotation file into a list of dataset dicts.
+
+    Args:
+        json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
+        image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
+        gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
+        meta (dict): provides the "thing_dataset_id_to_contiguous_id" and
+            "stuff_dataset_id_to_contiguous_id" mappings used to remap category ids.
+
+    Returns:
+        list[dict]: one dict per annotation entry, with the keys "file_name",
+        "image_id", "pan_seg_file_name" and "segments_info".
+    """
+
+    def _remap_segment(segment):
+        # Rewrites `segment` in place: the dataset category id becomes a contiguous
+        # id, and "isthing" records which of the two mappings supplied it.
+        dataset_id = segment["category_id"]
+        is_thing = dataset_id in meta["thing_dataset_id_to_contiguous_id"]
+        mapping_key = (
+            "thing_dataset_id_to_contiguous_id" if is_thing else "stuff_dataset_id_to_contiguous_id"
+        )
+        segment["category_id"] = meta[mapping_key][dataset_id]
+        segment["isthing"] = is_thing
+        return segment
+
+    with PathManager.open(json_file) as fh:
+        panoptic_info = json.load(fh)
+
+    records = []
+    for entry in panoptic_info["annotations"]:
+        image_id = int(entry["image_id"])
+        # TODO: the image is assumed to share the label's file name with a ".jpg" extension
+        # (true for COCO); make the extension an argument to support COCO-like datasets.
+        stem = os.path.splitext(entry["file_name"])[0]
+        records.append(
+            {
+                "file_name": os.path.join(image_dir, stem + ".jpg"),
+                "image_id": image_id,
+                "pan_seg_file_name": os.path.join(gt_dir, entry["file_name"]),
+                "segments_info": [_remap_segment(seg) for seg in entry["segments_info"]],
+            }
+        )
+    assert len(records), f"No images found in {image_dir}!"
+    first = records[0]
+    assert PathManager.isfile(first["file_name"]), first["file_name"]
+    assert PathManager.isfile(first["pan_seg_file_name"]), first["pan_seg_file_name"]
+    return records
+
+
+def register_coco_panoptic(
+    name, metadata, image_root, panoptic_root, panoptic_json, instances_json=None
+):
+    """
+    Register the "standard" COCO panoptic segmentation dataset under `name`.
+    "Standard" means the dicts of this dataset follow detectron2's standard format.
+
+    Args:
+        name (str): the name that identifies a dataset,
+            e.g. "coco_2017_train_panoptic"
+        metadata (dict): extra metadata associated with this dataset.
+        image_root (str): directory which contains all the images
+        panoptic_root (str): directory which contains panoptic annotation images in COCO format
+        panoptic_json (str): path to the json panoptic annotation file in COCO format
+        instances_json (str): path to the json instance annotation file; it is stored
+            in the metadata under `json_file`. Defaults to None.
+    """
+
+    def _load_dicts():
+        # Handed to the catalog as a callable instead of being invoked here.
+        return load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, metadata)
+
+    DatasetCatalog.register(name, _load_dicts)
+    # Record the data locations and evaluator settings alongside the caller's metadata.
+    MetadataCatalog.get(name).set(
+        panoptic_root=panoptic_root,
+        image_root=image_root,
+        panoptic_json=panoptic_json,
+        json_file=instances_json,
+        evaluator_type="coco_panoptic_seg",
+        ignore_label=255,
+        label_divisor=1000,
+        **metadata,
+    )
+
+
+def register_coco_panoptic_separated(
+    name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json
+):
+    """
+    Register the "separated" COCO panoptic segmentation dataset, plus a stuff-only one.
+
+    "Separated" means each registered dict carries both instance annotations and
+    semantic annotations, and each of the two uses its own contiguous ids.
+    This is the setting used by the PanopticFPN paper:
+
+    1. Instance annotations are taken from the polygons of the COCO instances
+       annotation task, not from the masks of the COCO panoptic annotations.
+
+       The two formats differ slightly: polygons in the instance annotations
+       may overlap, whereas the mask annotations are produced by labeling the
+       overlapped polygons with depth ordering.
+
+    2. Semantic annotations are converted from the panoptic annotations, with
+       all "things" assigned a semantic id of 0. All semantic categories
+       therefore have ids in the contiguous range [1, #stuff_categories].
+
+    The merged dataset is registered as ``name + '_separated'``; a pure semantic
+    segmentation dataset is also registered as ``name + '_stuffonly'``.
+
+    Args:
+        name (str): the name that identifies a dataset,
+            e.g. "coco_2017_train_panoptic"
+        metadata (dict): extra metadata associated with this dataset.
+        image_root (str): directory which contains all the images
+        panoptic_root (str): directory which contains panoptic annotation images
+        panoptic_json (str): path to the json panoptic annotation file
+        sem_seg_root (str): directory which contains all the ground truth segmentation annotations.
+        instances_json (str): path to the json instance annotation file
+    """
+    separated_name = name + "_separated"
+    stuffonly_name = name + "_stuffonly"
+
+    def _load_sem_seg():
+        return load_sem_seg(sem_seg_root, image_root)
+
+    def _load_separated():
+        detection_dicts = load_coco_json(instances_json, image_root, separated_name)
+        return merge_to_panoptic(detection_dicts, _load_sem_seg())
+
+    DatasetCatalog.register(separated_name, _load_separated)
+    MetadataCatalog.get(separated_name).set(
+        panoptic_root=panoptic_root,
+        image_root=image_root,
+        panoptic_json=panoptic_json,
+        sem_seg_root=sem_seg_root,
+        json_file=instances_json,  # TODO rename
+        evaluator_type="coco_panoptic_seg",
+        ignore_label=255,
+        **metadata,
+    )
+
+    DatasetCatalog.register(stuffonly_name, _load_sem_seg)
+    MetadataCatalog.get(stuffonly_name).set(
+        sem_seg_root=sem_seg_root,
+        image_root=image_root,
+        evaluator_type="sem_seg",
+        ignore_label=255,
+        **metadata,
+    )
+
+
+def merge_to_panoptic(detection_dicts, sem_seg_dicts):
+    """
+    Build panoptic segmentation dataset dicts by pairing detection dicts with
+    semantic segmentation dicts that have the same "file_name".
+
+    Args:
+        detection_dicts (list[dict]): lists of dicts for object detection or instance segmentation.
+        sem_seg_dicts (list[dict]): lists of dicts for semantic segmentation.
+
+    Returns:
+        list[dict]: one dict per entry of `detection_dicts`: a shallow copy of that entry updated
+        with the matching `sem_seg_dicts` entry. Keys present in both are assumed to hold
+        the same value.
+    """
+    sem_seg_by_file = dict((entry["file_name"], entry) for entry in sem_seg_dicts)
+    assert len(sem_seg_by_file) > 0
+
+    merged_dicts = []
+    for det in detection_dicts:
+        merged = copy.copy(det)
+        merged.update(sem_seg_by_file[merged["file_name"]])
+        merged_dicts.append(merged)
+    return merged_dicts
+
+
+if __name__ == "__main__":
+    # Smoke test for the COCO panoptic dataset loader.
+    #
+    # Command-line arguments:
+    #   argv[1]  path to the image root
+    #   argv[2]  path to the panoptic root
+    #   argv[3]  path to the panoptic json file
+    #   argv[4]  dataset name; must already be registered, e.g.
+    #            "coco_2017_train_panoptic"
+    #   argv[5]  number of images to visualize, e.g. 10 (checked after each save,
+    #            so at least one image is always written)
+    from annotator.oneformer.detectron2.utils.logger import setup_logger
+    from annotator.oneformer.detectron2.utils.visualizer import Visualizer
+    import annotator.oneformer.detectron2.data.datasets  # noqa # add pre-defined metadata
+    import sys
+    from PIL import Image
+    import numpy as np
+
+    logger = setup_logger(name=__name__)
+    image_root, panoptic_root = sys.argv[1], sys.argv[2]
+    panoptic_json, dataset_name = sys.argv[3], sys.argv[4]
+    assert dataset_name in DatasetCatalog.list()
+    meta = MetadataCatalog.get(dataset_name)
+
+    samples = load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, meta.as_dict())
+    logger.info("Done loading {} samples.".format(len(samples)))
+
+    out_dir = "coco-data-vis"
+    os.makedirs(out_dir, exist_ok=True)
+    limit = int(sys.argv[5])
+    for num_done, sample in enumerate(samples, start=1):
+        image = np.array(Image.open(sample["file_name"]))
+        rendered = Visualizer(image, metadata=meta).draw_dataset_dict(sample)
+        rendered.save(os.path.join(out_dir, os.path.basename(sample["file_name"])))
+        if num_done >= limit:
+            break
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e1e6ecc657e83d6df57da342b0655177402c514
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis.py
@@ -0,0 +1,241 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import os
+from fvcore.common.timer import Timer
+
+from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
+from annotator.oneformer.detectron2.structures import BoxMode
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .builtin_meta import _get_coco_instances_meta
+from .lvis_v0_5_categories import LVIS_CATEGORIES as LVIS_V0_5_CATEGORIES
+from .lvis_v1_categories import LVIS_CATEGORIES as LVIS_V1_CATEGORIES
+from .lvis_v1_category_image_count import LVIS_CATEGORY_IMAGE_COUNT as LVIS_V1_CATEGORY_IMAGE_COUNT
+
+"""
+This file contains functions to parse LVIS-format annotations into dicts in the
+"Detectron2 format".
+"""
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["load_lvis_json", "register_lvis_instances", "get_lvis_instances_meta"]
+
+
+def register_lvis_instances(name, metadata, json_file, image_root):
+    """
+    Register, under `name`, an instance detection/segmentation dataset in LVIS's json format.
+
+    Args:
+        name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train".
+        metadata (dict): extra metadata associated with this dataset. It can be an empty dict.
+        json_file (str): path to the json instance annotation file.
+        image_root (str or path-like): directory which contains all the images.
+    """
+    loader_args = (json_file, image_root, name)
+    DatasetCatalog.register(name, lambda: load_lvis_json(*loader_args))
+    dataset_meta = MetadataCatalog.get(name)
+    dataset_meta.set(json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata)
+
+
+def load_lvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
+    """
+    Load a json file in LVIS's annotation format.
+
+    Args:
+        json_file (str): full path to the LVIS json annotation file.
+        image_root (str or path-like): directory containing the split folders of the images.
+        dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train").
+            If provided, this function will put "thing_classes" into the metadata
+            associated with this dataset.
+        extra_annotation_keys (list[str]): list of per-annotation keys that should also be
+            loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id",
+            "segmentation"). The values for these keys will be returned as-is.
+
+    Returns:
+        list[dict]: a list of dicts in Detectron2 standard format. (See the
+        "Using Custom Datasets" tutorial.)
+
+    Notes:
+        1. This function does not read the image files.
+           The results do not have the "image" field.
+    """
+    from lvis import LVIS
+
+    json_file = PathManager.get_local_path(json_file)
+
+    timer = Timer()
+    lvis_api = LVIS(json_file)
+    if timer.seconds() > 1:
+        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
+
+    if dataset_name is not None:
+        meta = get_lvis_instances_meta(dataset_name)
+        MetadataCatalog.get(dataset_name).set(**meta)
+
+    # sort indices for reproducible results
+    img_ids = sorted(lvis_api.imgs.keys())
+    # imgs is a list of dicts, each looks something like:
+    # {'license': 4,
+    #  'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg',
+    #  'file_name': 'COCO_val2014_000000001268.jpg',
+    #  'height': 427,
+    #  'width': 640,
+    #  'date_captured': '2013-11-17 05:57:24',
+    #  'id': 1268}
+    imgs = lvis_api.load_imgs(img_ids)
+    # anns is a list[list[dict]], where each dict is an annotation
+    # record for an object. The inner list enumerates the objects in an image
+    # and the outer list enumerates over images. Example of anns[0]:
+    # [{'segmentation': [[192.81,
+    #     247.09,
+    #     ...
+    #     219.03,
+    #     249.06]],
+    #   'area': 1035.749,
+    #   'image_id': 1268,
+    #   'bbox': [192.81, 224.8, 74.73, 33.43],
+    #   'category_id': 16,
+    #   'id': 42986},
+    #  ...]
+    anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
+
+    # Sanity check that each annotation has a unique id
+    ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
+    assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format(
+        json_file
+    )
+
+    imgs_anns = list(zip(imgs, anns))
+
+    logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file))
+
+    if extra_annotation_keys:
+        logger.info(
+            "The following extra annotation keys will be loaded: {} ".format(extra_annotation_keys)
+        )
+    else:
+        extra_annotation_keys = []
+
+    def get_file_name(img_root, img_dict):
+        # Determine the path including the split folder ("train2017", "val2017", "test2017") from
+        # the coco_url field, e.g. 'http://images.cocodataset.org/train2017/000000155379.jpg'.
+        # Joined per component so `img_root` needs no trailing separator and may be path-like.
+        split_folder, file_name = img_dict["coco_url"].split("/")[-2:]
+        return os.path.join(img_root, split_folder, file_name)
+
+    dataset_dicts = []
+
+    for (img_dict, anno_dict_list) in imgs_anns:
+        record = {}
+        record["file_name"] = get_file_name(image_root, img_dict)
+        record["height"] = img_dict["height"]
+        record["width"] = img_dict["width"]
+        record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", [])
+        record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
+        image_id = record["image_id"] = img_dict["id"]
+
+        objs = []
+        for anno in anno_dict_list:
+            # Check that the image_id in this annotation is the same as
+            # the image_id we're looking at.
+            # This fails only when the data parsing logic or the annotation file is buggy.
+            assert anno["image_id"] == image_id
+            obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
+            # LVIS data loader can be used to load COCO dataset categories. In this case `meta`
+            # variable will have a field with COCO-specific category mapping.
+            if dataset_name is not None and "thing_dataset_id_to_contiguous_id" in meta:
+                obj["category_id"] = meta["thing_dataset_id_to_contiguous_id"][anno["category_id"]]
+            else:
+                obj["category_id"] = anno["category_id"] - 1  # Convert 1-indexed to 0-indexed
+            segm = anno["segmentation"]  # list[list[float]]
+            # filter out invalid polygons (< 3 points)
+            valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
+            assert len(segm) == len(
+                valid_segm
+            ), "Annotation contains an invalid polygon with < 3 points"
+            assert len(segm) > 0
+            obj["segmentation"] = segm
+            for extra_ann_key in extra_annotation_keys:
+                obj[extra_ann_key] = anno[extra_ann_key]
+            objs.append(obj)
+        record["annotations"] = objs
+        dataset_dicts.append(record)
+
+    return dataset_dicts
+
+
+def get_lvis_instances_meta(dataset_name):
+    """
+    Load the built-in metadata selected by a marker found in `dataset_name`.
+
+    Args:
+        dataset_name (str): dataset name, tested for "cocofied", "v0.5" and "v1", in that order.
+
+    Returns:
+        dict: the metadata built by the getter of the first marker found; a ValueError
+        is raised when no marker is found.
+    """
+    markers = ("cocofied", "v0.5", "v1")
+    getters = (_get_coco_instances_meta, _get_lvis_instances_meta_v0_5, _get_lvis_instances_meta_v1)
+    for marker, getter in zip(markers, getters):
+        if marker in dataset_name:
+            return getter()
+    raise ValueError("No built-in metadata for dataset {}".format(dataset_name))
+
+
+def _get_lvis_instances_meta_v0_5():
+    """Build the LVIS v0.5 metadata dict from the built-in category list."""
+    assert len(LVIS_V0_5_CATEGORIES) == 1230
+    ids = [category["id"] for category in LVIS_V0_5_CATEGORIES]
+    assert (
+        min(ids) == 1 and max(ids) == len(ids)
+    ), "Category ids are not in [1, #categories], as expected"
+    # Order by id so that "thing_classes" is listed in category-id order.
+    by_id = sorted(LVIS_V0_5_CATEGORIES, key=lambda category: category["id"])
+    # Each category is named by the first of its synonyms.
+    return {"thing_classes": [category["synonyms"][0] for category in by_id]}
+
+
+def _get_lvis_instances_meta_v1():
+    """Build the LVIS v1 metadata dict: class names plus the per-category image counts."""
+    assert len(LVIS_V1_CATEGORIES) == 1203
+    ids = [category["id"] for category in LVIS_V1_CATEGORIES]
+    assert (
+        min(ids) == 1 and max(ids) == len(ids)
+    ), "Category ids are not in [1, #categories], as expected"
+    # Order by id so that "thing_classes" is listed in category-id order.
+    by_id = sorted(LVIS_V1_CATEGORIES, key=lambda category: category["id"])
+    thing_classes = [category["synonyms"][0] for category in by_id]
+    return {"thing_classes": thing_classes, "class_image_count": LVIS_V1_CATEGORY_IMAGE_COUNT}
+
+
+if __name__ == "__main__":
+    # Smoke test for the LVIS json dataset loader.
+    #
+    # Command-line arguments:
+    #   argv[1]  path to the json annotation file
+    #   argv[2]  path to the image root
+    #   argv[3]  dataset name
+    #   argv[4]  vis_limit: only the first `vis_limit` samples are visualized
+    import sys
+    import numpy as np
+    from annotator.oneformer.detectron2.utils.logger import setup_logger
+    from PIL import Image
+    import annotator.oneformer.detectron2.data.datasets  # noqa # add pre-defined metadata
+    from annotator.oneformer.detectron2.utils.visualizer import Visualizer
+
+    logger = setup_logger(name=__name__)
+    json_path, image_root, dataset_name = sys.argv[1], sys.argv[2], sys.argv[3]
+    meta = MetadataCatalog.get(dataset_name)
+
+    samples = load_lvis_json(json_path, image_root, dataset_name)
+    logger.info("Done loading {} samples.".format(len(samples)))
+
+    out_dir = "lvis-data-vis"
+    os.makedirs(out_dir, exist_ok=True)
+    vis_limit = int(sys.argv[4])
+    for sample in samples[:vis_limit]:
+        image = np.array(Image.open(sample["file_name"]))
+        rendered = Visualizer(image, metadata=meta).draw_dataset_dict(sample)
+        rendered.save(os.path.join(out_dir, os.path.basename(sample["file_name"])))
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v0_5_categories.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v0_5_categories.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3dab6198da614937b08682f4c9edf52bdf1d236
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v0_5_categories.py
@@ -0,0 +1,13 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Autogenerated with the following snippet:
+# with open("lvis_v0.5_val.json", "r") as f:
+#     a = json.load(f)
+# c = a["categories"]
+# for x in c:
+#     del x["image_count"]
+#     del x["instance_count"]
+# LVIS_CATEGORIES = repr(c) + " # noqa"
+
+# fmt: off
+LVIS_CATEGORIES = [{'frequency': 'r', 'id': 1, 'synset': 'acorn.n.01', 'synonyms': ['acorn'], 'def': 'nut from an oak tree', 'name': 'acorn'}, {'frequency': 'c', 'id': 2, 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'id': 3, 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'id': 4, 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'c', 'id': 5, 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'id': 6, 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'r', 'id': 7, 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'id': 8, 'synset': 'almond.n.02', 'synonyms': ['almond'], 'def': 'oval-shaped edible seed of the almond tree', 'name': 'almond'}, {'frequency': 'c', 'id': 9, 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'r', 'id': 10, 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'id': 11, 'synset': 'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'id': 12, 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 
'transmitting_aerial'], 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 'id': 13, 'synset': 'apple.n.01', 'synonyms': ['apple'], 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'id': 14, 'synset': 'apple_juice.n.01', 'synonyms': ['apple_juice'], 'def': 'the juice of apples', 'name': 'apple_juice'}, {'frequency': 'r', 'id': 15, 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'id': 16, 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'id': 17, 'synset': 'apron.n.01', 'synonyms': ['apron'], 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'id': 18, 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'c', 'id': 19, 'synset': 'armband.n.02', 'synonyms': ['armband'], 'def': 'a band worn around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'id': 20, 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'id': 21, 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'id': 22, 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, {'frequency': 'c', 'id': 23, 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'id': 24, 'synset': 
'ashcan.n.01', 'synonyms': ['trash_can', 'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'id': 25, 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'id': 26, 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'id': 27, 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'c', 'id': 28, 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'id': 29, 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'id': 30, 'synset': 'awning.n.01', 'synonyms': ['awning'], 'def': 'a canopy made of canvas to shelter people or things from rain or sun', 'name': 'awning'}, {'frequency': 'r', 'id': 31, 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'f', 'id': 32, 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'id': 33, 'synset': 'backboard.n.01', 'synonyms': ['basketball_backboard'], 'def': 'a raised vertical board with basket attached; used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'id': 34, 'synset': 'backpack.n.01', 'synonyms': 
['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'id': 35, 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'id': 36, 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'id': 37, 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'id': 38, 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'id': 39, 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'id': 40, 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'def': 'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'id': 41, 'synset': 'ball.n.06', 'synonyms': ['ball'], 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'id': 42, 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'id': 43, 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'id': 44, 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 'f', 'id': 45, 'synset': 'banana.n.02', 'synonyms': ['banana'], 'def': 'elongated crescent-shaped yellow fruit with soft sweet 
flesh', 'name': 'banana'}, {'frequency': 'r', 'id': 46, 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'id': 47, 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'c', 'id': 48, 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'id': 49, 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'id': 50, 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'id': 51, 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 'name': 'barbell'}, {'frequency': 'r', 'id': 52, 'synset': 'barge.n.01', 'synonyms': ['barge'], 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'id': 53, 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'id': 54, 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'id': 55, 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'def': 'a cart for carrying small loads; has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'id': 56, 'synset': 'base.n.03', 'synonyms': ['baseball_base'], 'def': 'a place that the runner must touch before scoring', 'name': 
'baseball_base'}, {'frequency': 'f', 'id': 57, 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'id': 58, 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'id': 59, 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'id': 60, 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'id': 61, 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'id': 62, 'synset': 'basket.n.03', 'synonyms': ['basketball_hoop'], 'def': 'metal hoop supporting a net through which players try to throw the basketball', 'name': 'basketball_hoop'}, {'frequency': 'c', 'id': 63, 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'id': 64, 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'r', 'id': 65, 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'id': 66, 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, {'frequency': 'f', 'id': 67, 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'id': 68, 'synset': 'bathrobe.n.01', 
'synonyms': ['bathrobe'], 'def': 'a loose-fitting robe of towelling; worn after a bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'id': 69, 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'id': 70, 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'id': 71, 'synset': 'battery.n.02', 'synonyms': ['battery'], 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'id': 72, 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'id': 73, 'synset': 'bead.n.01', 'synonyms': ['bead'], 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'r', 'id': 74, 'synset': 'beaker.n.01', 'synonyms': ['beaker'], 'def': 'a flatbottomed jar made of glass or plastic; used for chemistry', 'name': 'beaker'}, {'frequency': 'c', 'id': 75, 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'id': 76, 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'def': 'a bag filled with dried beans or similar items; used in games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'id': 77, 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'id': 78, 'synset': 'bear.n.01', 'synonyms': ['bear'], 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'id': 79, 'synset': 'bed.n.01', 'synonyms': ['bed'], 'def': 'a piece of furniture that 
provides a place to sleep', 'name': 'bed'}, {'frequency': 'c', 'id': 80, 'synset': 'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'id': 81, 'synset': 'beef.n.01', 'synonyms': ['cow'], 'def': 'cattle that are reared for their meat', 'name': 'cow'}, {'frequency': 'c', 'id': 82, 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'id': 83, 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'id': 84, 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'id': 85, 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'id': 86, 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'id': 87, 'synset': 'bell.n.01', 'synonyms': ['bell'], 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'id': 88, 'synset': 'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'id': 89, 'synset': 'belt.n.02', 'synonyms': ['belt'], 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'id': 90, 'synset': 'belt_buckle.n.01', 'synonyms': ['belt_buckle'], 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'id': 91, 'synset': 'bench.n.01', 'synonyms': ['bench'], 'def': 'a long seat for more than one person', 'name': 'bench'}, 
{'frequency': 'c', 'id': 92, 'synset': 'beret.n.01', 'synonyms': ['beret'], 'def': 'a cap with no brim or bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'id': 93, 'synset': 'bib.n.02', 'synonyms': ['bib'], 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'id': 94, 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'id': 95, 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'id': 96, 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'c', 'id': 97, 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'id': 98, 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 'id': 99, 'synset': 'bird.n.01', 'synonyms': ['bird'], 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'r', 'id': 100, 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'r', 'id': 101, 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'id': 102, 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'def': 'a cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'id': 103, 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'def': 'a shelter for birds', 'name': 'birdhouse'}, {'frequency': 'f', 'id': 104, 
'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'id': 105, 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'id': 106, 'synset': 'biscuit.n.01', 'synonyms': ['biscuit_(bread)'], 'def': 'small round bread leavened with baking-powder or soda', 'name': 'biscuit_(bread)'}, {'frequency': 'r', 'id': 107, 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'id': 108, 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'id': 109, 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'id': 110, 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'id': 111, 'synset': 'blazer.n.01', 'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'id': 112, 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'id': 113, 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, {'frequency': 'c', 'id': 114, 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, {'frequency': 'c', 'id': 115, 
'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'id': 116, 'synset': 'boar.n.02', 'synonyms': ['boar'], 'def': 'an uncastrated male hog', 'name': 'boar'}, {'frequency': 'r', 'id': 117, 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'id': 118, 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'c', 'id': 119, 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'r', 'id': 120, 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'id': 121, 'synset': 'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'id': 122, 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'id': 123, 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'id': 124, 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'id': 125, 'synset': 'bonnet.n.01', 'synonyms': ['bonnet'], 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'id': 126, 'synset': 'book.n.01', 'synonyms': ['book'], 'def': 'a written work or composition that has been published', 'name': 
'book'}, {'frequency': 'r', 'id': 127, 'synset': 'book_bag.n.01', 'synonyms': ['book_bag'], 'def': 'a bag in which students carry their books', 'name': 'book_bag'}, {'frequency': 'c', 'id': 128, 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'id': 129, 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'id': 130, 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'id': 131, 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'id': 132, 'synset': 'boot.n.01', 'synonyms': ['boot'], 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'id': 133, 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'id': 134, 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'id': 135, 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'id': 136, 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 'f', 'id': 137, 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, {'frequency': 'f', 'id': 138, 'synset': 'bow_tie.n.01', 
'synonyms': ['bow-tie', 'bowtie'], 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'id': 139, 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'id': 140, 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'id': 141, 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'id': 142, 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'r', 'id': 143, 'synset': 'bowling_pin.n.01', 'synonyms': ['bowling_pin'], 'def': 'a club-shaped wooden object used in bowling', 'name': 'bowling_pin'}, {'frequency': 'r', 'id': 144, 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 'boxing_glove'}, {'frequency': 'c', 'id': 145, 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'id': 146, 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'id': 147, 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'id': 148, 'synset': 'brassiere.n.01', 'synonyms': ['brassiere', 'bra', 'bandeau'], 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'id': 149, 'synset': 'bread-bin.n.01', 'synonyms': ['bread-bin', 'breadbox'], 'def': 'a 
container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'r', 'id': 150, 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'c', 'id': 151, 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'id': 152, 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'c', 'id': 153, 'synset': 'bristle_brush.n.01', 'synonyms': ['bristle_brush'], 'def': 'a brush that is made with the short stiff hairs of an animal or plant', 'name': 'bristle_brush'}, {'frequency': 'f', 'id': 154, 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'id': 155, 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'id': 156, 'synset': 'broom.n.01', 'synonyms': ['broom'], 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'id': 157, 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'id': 158, 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'id': 159, 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'id': 160, 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'def': 'a roughly cylindrical vessel that is open at the 
top', 'name': 'bucket'}, {'frequency': 'r', 'id': 161, 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'id': 162, 'synset': 'bull.n.11', 'synonyms': ['bull'], 'def': 'mature male cow', 'name': 'bull'}, {'frequency': 'r', 'id': 163, 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'id': 164, 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'id': 165, 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'id': 166, 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'id': 167, 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'def': 'a vest capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'id': 168, 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'r', 'id': 169, 'synset': 'bully_beef.n.01', 'synonyms': ['corned_beef', 'corn_beef'], 'def': 'beef cured or pickled in brine', 'name': 'corned_beef'}, {'frequency': 'f', 'id': 170, 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 'id': 171, 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'id': 172, 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'def': 'a float attached by 
rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 'r', 'id': 173, 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'id': 174, 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'id': 175, 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'c', 'id': 176, 'synset': 'butcher_knife.n.01', 'synonyms': ['butcher_knife'], 'def': 'a large sharp knife for cutting or trimming meat', 'name': 'butcher_knife'}, {'frequency': 'c', 'id': 177, 'synset': 'butter.n.01', 'synonyms': ['butter'], 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'id': 178, 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'id': 179, 'synset': 'button.n.01', 'synonyms': ['button'], 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'id': 180, 'synset': 'cab.n.03', 'synonyms': ['cab_(taxi)', 'taxi', 'taxicab'], 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'id': 181, 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'def': 'a small tent used as a dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'r', 'id': 182, 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'def': 'a car on a freight train for use of the train crew; usually 
the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'id': 183, 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'id': 184, 'synset': 'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'id': 185, 'synset': 'cake.n.03', 'synonyms': ['cake'], 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'id': 186, 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'id': 187, 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'id': 188, 'synset': 'calf.n.01', 'synonyms': ['calf'], 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'id': 189, 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'id': 190, 'synset': 'camel.n.01', 'synonyms': ['camel'], 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'id': 191, 'synset': 'camera.n.01', 'synonyms': ['camera'], 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'id': 192, 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, {'frequency': 'c', 'id': 193, 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 
'f', 'id': 194, 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'id': 195, 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'def': 'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'r', 'id': 196, 'synset': 'candelabrum.n.01', 'synonyms': ['candelabrum', 'candelabra'], 'def': 'branched candlestick; ornamental; has several lights', 'name': 'candelabrum'}, {'frequency': 'f', 'id': 197, 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'id': 198, 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'id': 199, 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'id': 200, 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'id': 201, 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'id': 202, 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'r', 'id': 203, 'synset': 'cannon.n.02', 'synonyms': ['cannon'], 'def': 'heavy gun fired from a tank', 'name': 'cannon'}, {'frequency': 'c', 'id': 204, 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'def': 'small and light boat; pointed at both ends; propelled with a paddle', 'name': 'canoe'}, {'frequency': 'r', 'id': 205, 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with 
yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'id': 206, 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'c', 'id': 207, 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'id': 208, 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'r', 'id': 209, 'synset': 'cape.n.02', 'synonyms': ['cape'], 'def': 'a sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'id': 210, 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'id': 211, 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'id': 212, 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'def': 'a wheeled vehicle adapted to the rails of railroad', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'id': 213, 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'id': 214, 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'id': 215, 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'def': 'a card certifying the identity of the bearer', 'name': 'identity_card'}, {'frequency': 'c', 'id': 216, 'synset': 'card.n.03', 'synonyms': ['card'], 'def': 'a rectangular piece of paper used to send messages (e.g. 
greetings or pictures)', 'name': 'card'}, {'frequency': 'r', 'id': 217, 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'id': 218, 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'id': 219, 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'id': 220, 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'id': 221, 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'c', 'id': 222, 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'id': 223, 'synset': 'cart.n.01', 'synonyms': ['cart'], 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'id': 224, 'synset': 'carton.n.02', 'synonyms': ['carton'], 'def': 'a box made of cardboard; opens by flaps on top', 'name': 'carton'}, {'frequency': 'c', 'id': 225, 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'id': 226, 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'id': 227, 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'id': 228, 'synset': 
'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'id': 229, 'synset': 'cat.n.01', 'synonyms': ['cat'], 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'c', 'id': 230, 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'r', 'id': 231, 'synset': 'caviar.n.01', 'synonyms': ['caviar', 'caviare'], 'def': "salted roe of sturgeon or other large fish; usually served as an hors d'oeuvre", 'name': 'caviar'}, {'frequency': 'c', 'id': 232, 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'id': 233, 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'c', 'id': 234, 'synset': 'celery.n.01', 'synonyms': ['celery'], 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'id': 235, 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'id': 236, 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'id': 237, 'synset': 'chair.n.01', 'synonyms': ['chair'], 'def': 'a seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'id': 238, 'synset': 'chaise_longue.n.01', 'synonyms': 
['chaise_longue', 'chaise', 'daybed'], 'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'id': 239, 'synset': 'champagne.n.01', 'synonyms': ['champagne'], 'def': 'a white sparkling wine produced in Champagne or resembling that produced there', 'name': 'champagne'}, {'frequency': 'f', 'id': 240, 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'id': 241, 'synset': 'chap.n.04', 'synonyms': ['chap'], 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'id': 242, 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'id': 243, 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'id': 244, 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'id': 245, 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'r', 'id': 246, 'synset': 'chest_of_drawers.n.01', 'synonyms': ['chest_of_drawers_(furniture)', 'bureau_(furniture)', 'chest_(furniture)'], 'def': 'furniture with drawers for keeping clothes', 'name': 'chest_of_drawers_(furniture)'}, {'frequency': 'c', 'id': 247, 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'id': 248, 'synset': 'chicken_wire.n.01', 'synonyms': ['chicken_wire'], 'def': 'a galvanized wire network with a hexagonal mesh; used to build fences', 'name': 'chicken_wire'}, {'frequency': 'r', 'id': 249, 'synset': 'chickpea.n.01', 
'synonyms': ['chickpea', 'garbanzo'], 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'r', 'id': 250, 'synset': 'chihuahua.n.03', 'synonyms': ['Chihuahua'], 'def': 'an old breed of tiny short-haired dog with protruding eyes from Mexico', 'name': 'Chihuahua'}, {'frequency': 'r', 'id': 251, 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'id': 252, 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'def': 'an instrument consisting of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'id': 253, 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'id': 254, 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'id': 255, 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'id': 256, 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'id': 257, 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'id': 258, 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'id': 259, 'synset': 'chocolate_mousse.n.01', 'synonyms': ['chocolate_mousse'], 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'id': 260, 'synset': 
'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'def': 'necklace that fits tightly around the neck', 'name': 'choker'}, {'frequency': 'f', 'id': 261, 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'c', 'id': 262, 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'id': 263, 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'def': 'an ornamented evergreen used as a Christmas decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'id': 264, 'synset': 'chute.n.02', 'synonyms': ['slide'], 'def': 'sloping channel through which things can descend', 'name': 'slide'}, {'frequency': 'r', 'id': 265, 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'id': 266, 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'c', 'id': 267, 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'id': 268, 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'id': 269, 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'id': 270, 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'def': 'a single-reed instrument with a straight tube', 'name': 'clarinet'}, {'frequency': 'r', 'id': 271, 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'def': 'a fastener (as a buckle or 
hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'id': 272, 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'id': 273, 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'id': 274, 'synset': 'clip.n.03', 'synonyms': ['clip'], 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'id': 275, 'synset': 'clipboard.n.01', 'synonyms': ['clipboard'], 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'f', 'id': 276, 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'id': 277, 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'id': 278, 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'id': 279, 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'id': 280, 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'id': 281, 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'def': 'a covering (plate or mat) that protects the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'id': 282, 'synset': 'coat.n.01', 'synonyms': ['coat'], 'def': 'an 
outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'id': 283, 'synset': 'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'def': "a hanger that is shaped like a person's shoulders", 'name': 'coat_hanger'}, {'frequency': 'r', 'id': 284, 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'id': 285, 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'c', 'id': 286, 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'r', 'id': 287, 'synset': 'coffee_filter.n.01', 'synonyms': ['coffee_filter'], 'def': 'filter (usually of paper) that passes the coffee and retains the coffee grounds', 'name': 'coffee_filter'}, {'frequency': 'f', 'id': 288, 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'id': 289, 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'id': 290, 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'id': 291, 'synset': 'coil.n.05', 'synonyms': ['coil'], 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'id': 292, 'synset': 'coin.n.01', 'synonyms': ['coin'], 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 'r', 'id': 293, 'synset': 'colander.n.01', 'synonyms': ['colander', 'cullender'], 'def': 'bowl-shaped strainer; used to wash or drain 
foods', 'name': 'colander'}, {'frequency': 'c', 'id': 294, 'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'id': 295, 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'id': 296, 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'id': 297, 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'id': 298, 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'f', 'id': 299, 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'r', 'id': 300, 'synset': 'concrete_mixer.n.01', 'synonyms': ['concrete_mixer', 'cement_mixer'], 'def': 'a machine with a large revolving drum in which cement/concrete is mixed', 'name': 'concrete_mixer'}, {'frequency': 'f', 'id': 301, 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'id': 302, 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'id': 303, 'synset': 'convertible.n.01', 'synonyms': ['convertible_(automobile)'], 'def': 'a car that has top that can be folded or removed', 'name': 'convertible_(automobile)'}, {'frequency': 'r', 'id': 304, 'synset': 'convertible.n.03', 'synonyms': ['sofa_bed'], 'def': 'a sofa that can be 
converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'c', 'id': 305, 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'id': 306, 'synset': 'cookie_jar.n.01', 'synonyms': ['cookie_jar', 'cooky_jar'], 'def': 'a jar in which cookies are kept (and sometimes money is hidden)', 'name': 'cookie_jar'}, {'frequency': 'r', 'id': 307, 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'id': 308, 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'c', 'id': 309, 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'id': 310, 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'r', 'id': 311, 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'c', 'id': 312, 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'def': 'ears of corn that can be prepared and served for human food', 'name': 'edible_corn'}, {'frequency': 'r', 'id': 313, 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'id': 314, 'synset': 'cornet.n.01', 'synonyms': ['cornet', 'horn', 'trumpet'], 'def': 'a brass musical instrument with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, {'frequency': 'c', 'id': 315, 'synset': 'cornice.n.01', 'synonyms': 
['cornice', 'valance', 'valance_board', 'pelmet'], 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'id': 316, 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'r', 'id': 317, 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'r', 'id': 318, 'synset': 'cos.n.02', 'synonyms': ['romaine_lettuce'], 'def': 'lettuce with long dark-green leaves in a loosely packed elongated head', 'name': 'romaine_lettuce'}, {'frequency': 'c', 'id': 319, 'synset': 'costume.n.04', 'synonyms': ['costume'], 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'id': 320, 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'id': 321, 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'r', 'id': 322, 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'id': 323, 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'r', 'id': 324, 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'c', 'id': 325, 'synset': 'cracker.n.01', 'synonyms': ['cracker'], 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'id': 326, 'synset': 'crape.n.01', 'synonyms': 
['crape', 'crepe', 'French_pancake'], 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'id': 327, 'synset': 'crate.n.01', 'synonyms': ['crate'], 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'r', 'id': 328, 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'id': 329, 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'r', 'id': 330, 'synset': 'credit_card.n.01', 'synonyms': ['credit_card', 'charge_card', 'debit_card'], 'def': 'a card, usually plastic, used to pay for goods and services', 'name': 'credit_card'}, {'frequency': 'c', 'id': 331, 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'id': 332, 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'id': 333, 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'def': 'an earthen jar (made of baked clay)', 'name': 'crock_pot'}, {'frequency': 'f', 'id': 334, 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'id': 335, 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'r', 'id': 336, 'synset': 'crow.n.01', 'synonyms': ['crow'], 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'c', 'id': 337, 'synset': 'crown.n.04', 'synonyms': ['crown'], 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 'id': 338, 'synset': 
'crucifix.n.01', 'synonyms': ['crucifix'], 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'id': 339, 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'id': 340, 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'c', 'id': 341, 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'r', 'id': 342, 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'id': 343, 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'r', 'id': 344, 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'id': 345, 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'id': 346, 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'id': 347, 'synset': 'cup.n.01', 'synonyms': ['cup'], 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'id': 348, 'synset': 'cup.n.08', 'synonyms': ['trophy_cup'], 'def': 'a metal vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, {'frequency': 'c', 'id': 349, 'synset': 
'cupcake.n.01', 'synonyms': ['cupcake'], 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'id': 350, 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'id': 351, 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'id': 352, 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'id': 353, 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'id': 354, 'synset': 'custard.n.01', 'synonyms': ['custard'], 'def': 'sweetened mixture of milk and eggs baked or boiled or frozen', 'name': 'custard'}, {'frequency': 'c', 'id': 355, 'synset': 'cutter.n.06', 'synonyms': ['cutting_tool'], 'def': 'a cutting implement; a tool for cutting', 'name': 'cutting_tool'}, {'frequency': 'r', 'id': 356, 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'id': 357, 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'id': 358, 'synset': 'dachshund.n.01', 'synonyms': ['dachshund', 'dachsie', 'badger_dog'], 'def': 'small long-bodied short-legged breed of dog having a short sleek coat and long drooping ears', 'name': 'dachshund'}, {'frequency': 'r', 'id': 359, 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'def': 'a short knife with a pointed blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'id': 360, 'synset': 'dartboard.n.01', 'synonyms': 
['dartboard'], 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'id': 361, 'synset': 'date.n.08', 'synonyms': ['date_(fruit)'], 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'id': 362, 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'id': 363, 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'id': 364, 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'id': 365, 'synset': 'desk.n.01', 'synonyms': ['desk'], 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'id': 366, 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'id': 367, 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'id': 368, 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'def': 'a daily written record of (usually personal) experiences and observations', 'name': 'diary'}, {'frequency': 'r', 'id': 369, 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'id': 370, 'synset': 'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'def': 'a small boat of shallow draft with seats and oars with which it is 
propelled', 'name': 'dinghy'}, {'frequency': 'f', 'id': 371, 'synset': 'dining_table.n.01', 'synonyms': ['dining_table'], 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'id': 372, 'synset': 'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'c', 'id': 373, 'synset': 'dish.n.01', 'synonyms': ['dish'], 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'id': 374, 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'id': 375, 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'def': 'a cloth for washing dishes', 'name': 'dishrag'}, {'frequency': 'c', 'id': 376, 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'id': 377, 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'id': 378, 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid'], 'def': 'a low-sudsing detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'r', 'id': 379, 'synset': 'diskette.n.01', 'synonyms': ['diskette', 'floppy', 'floppy_disk'], 'def': 'a small plastic magnetic disk enclosed in a stiff envelope used to store data', 'name': 'diskette'}, {'frequency': 'c', 'id': 380, 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'c', 'id': 381, 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'def': 'a disposable cup made of paper; for holding drinks', 
'name': 'Dixie_cup'}, {'frequency': 'f', 'id': 382, 'synset': 'dog.n.01', 'synonyms': ['dog'], 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'id': 383, 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'c', 'id': 384, 'synset': 'doll.n.01', 'synonyms': ['doll'], 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'id': 385, 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'id': 386, 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'id': 387, 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'r', 'id': 388, 'synset': 'domino.n.03', 'synonyms': ['eye_mask'], 'def': 'a mask covering the upper part of the face but with holes for the eyes', 'name': 'eye_mask'}, {'frequency': 'r', 'id': 389, 'synset': 'doorbell.n.01', 'synonyms': ['doorbell', 'buzzer'], 'def': 'a button at an outer door that gives a ringing or buzzing signal when pushed', 'name': 'doorbell'}, {'frequency': 'f', 'id': 390, 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'id': 391, 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'id': 392, 'synset': 'doughnut.n.02', 'synonyms': ['doughnut', 'donut'], 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'id': 393, 'synset': 
'dove.n.01', 'synonyms': ['dove'], 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'id': 394, 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'def': 'slender-bodied non-stinging insect having iridescent wings that are outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'id': 395, 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'def': 'a boxlike container in a piece of furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'id': 396, 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'id': 397, 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'id': 398, 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'c', 'id': 399, 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'c', 'id': 400, 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'id': 401, 'synset': 'drill.n.01', 'synonyms': ['drill'], 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'id': 402, 'synset': 'drinking_fountain.n.01', 'synonyms': ['drinking_fountain'], 'def': 'a public fountain to provide a jet of drinking water', 'name': 'drinking_fountain'}, {'frequency': 'r', 'id': 403, 'synset': 'drone.n.04', 'synonyms': ['drone'], 'def': 'an aircraft without a pilot that is operated by remote control', 'name': 'drone'}, {'frequency': 'r', 'id': 404, 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 
'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'id': 405, 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'id': 406, 'synset': 'drumstick.n.02', 'synonyms': ['drumstick'], 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'id': 407, 'synset': 'duck.n.01', 'synonyms': ['duck'], 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'r', 'id': 408, 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'id': 409, 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'id': 410, 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'def': 'a large cylindrical bag of heavy cloth', 'name': 'duffel_bag'}, {'frequency': 'r', 'id': 411, 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'id': 412, 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'id': 413, 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'r', 'id': 414, 'synset': 'dutch_oven.n.02', 'synonyms': ['Dutch_oven'], 'def': 'iron or earthenware cooking pot; used for stews', 'name': 'Dutch_oven'}, {'frequency': 'c', 'id': 415, 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'def': 'large birds of prey noted for their broad wings and 
strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'id': 416, 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'id': 417, 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'id': 418, 'synset': 'earring.n.01', 'synonyms': ['earring'], 'def': 'jewelry to ornament the ear', 'name': 'earring'}, {'frequency': 'c', 'id': 419, 'synset': 'easel.n.01', 'synonyms': ['easel'], 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, {'frequency': 'r', 'id': 420, 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'id': 421, 'synset': 'eel.n.01', 'synonyms': ['eel'], 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'id': 422, 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'id': 423, 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'id': 424, 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'id': 425, 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'id': 426, 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'def': 'egg-shaped vegetable having a shiny skin typically dark purple', 'name': 'eggplant'}, {'frequency': 'r', 'id': 427, 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'def': 'a 
chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'id': 428, 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'id': 429, 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'r', 'id': 430, 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'def': 'large northern deer with enormous flattened antlers in the male', 'name': 'elk'}, {'frequency': 'c', 'id': 431, 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'def': 'a flat (usually rectangular) container for a letter, thin package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'id': 432, 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'id': 433, 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'id': 434, 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'id': 435, 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'id': 436, 'synset': 'fan.n.01', 'synonyms': ['fan'], 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'id': 437, 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'id': 438, 'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 
'id': 439, 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'id': 440, 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'r', 'id': 441, 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular schedule', 'name': 'ferry'}, {'frequency': 'r', 'id': 442, 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'id': 443, 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'id': 444, 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'id': 445, 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'id': 446, 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'id': 447, 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'c', 'id': 448, 'synset': 'fire_engine.n.01', 'synonyms': ['fire_engine', 'fire_truck'], 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, 
{'frequency': 'c', 'id': 449, 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'id': 450, 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'id': 451, 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'def': 'an open recess in a wall at the base of a chimney where a fire can be built', 'name': 'fireplace'}, {'frequency': 'f', 'id': 452, 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'def': 'an upright hydrant for drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'c', 'id': 453, 'synset': 'fish.n.01', 'synonyms': ['fish'], 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'r', 'id': 454, 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'id': 455, 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'r', 'id': 456, 'synset': 'fishing_boat.n.01', 'synonyms': ['fishing_boat', 'fishing_vessel'], 'def': 'a vessel for fishing', 'name': 'fishing_boat'}, {'frequency': 'c', 'id': 457, 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'id': 458, 'synset': 'flag.n.01', 'synonyms': ['flag'], 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'id': 459, 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 
'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'id': 460, 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'id': 461, 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'r', 'id': 462, 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'def': 'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'id': 463, 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'id': 464, 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'id': 465, 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'id': 466, 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'id': 467, 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'id': 468, 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'r', 'id': 469, 'synset': 'foal.n.01', 'synonyms': ['foal'], 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'id': 470, 'synset': 'folding_chair.n.01', 'synonyms': ['folding_chair'], 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'id': 471, 'synset': 
'food_processor.n.01', 'synonyms': ['food_processor'], 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'id': 472, 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'def': 'the inflated oblong ball used in playing American football', 'name': 'football_(American)'}, {'frequency': 'r', 'id': 473, 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'def': 'a padded helmet with a face mask to protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'id': 474, 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'id': 475, 'synset': 'fork.n.01', 'synonyms': ['fork'], 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'r', 'id': 476, 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'r', 'id': 477, 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'r', 'id': 478, 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'id': 479, 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'def': 'anything that freshens', 'name': 'freshener'}, {'frequency': 'f', 'id': 480, 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'id': 481, 'synset': 'frog.n.01', 'synonyms': ['frog', 'toad', 'toad_frog'], 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, 
{'frequency': 'c', 'id': 482, 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'r', 'id': 483, 'synset': 'fruit_salad.n.01', 'synonyms': ['fruit_salad'], 'def': 'salad composed of fruits', 'name': 'fruit_salad'}, {'frequency': 'c', 'id': 484, 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'def': 'a pan used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'id': 485, 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'def': 'soft creamy candy', 'name': 'fudge'}, {'frequency': 'r', 'id': 486, 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'c', 'id': 487, 'synset': 'futon.n.01', 'synonyms': ['futon'], 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'id': 488, 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'id': 489, 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'id': 490, 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'id': 491, 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'id': 492, 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'id': 493, 'synset': 'gargoyle.n.02', 'synonyms': ['gargoyle'], 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 
'gargoyle'}, {'frequency': 'c', 'id': 494, 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'id': 495, 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'r', 'id': 496, 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'def': 'small swift graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'id': 497, 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'id': 498, 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'c', 'id': 499, 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'id': 500, 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'id': 501, 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'id': 502, 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'id': 503, 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'def': 'a band of material around the waist that strengthens a skirt or trousers', 'name': 'cincture'}, {'frequency': 'f', 'id': 504, 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 
'drinking_glass'], 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'id': 505, 'synset': 'globe.n.03', 'synonyms': ['globe'], 'def': 'a sphere on which a map (especially of the earth) is represented', 'name': 'globe'}, {'frequency': 'f', 'id': 506, 'synset': 'glove.n.02', 'synonyms': ['glove'], 'def': 'handwear covering the hand', 'name': 'glove'}, {'frequency': 'c', 'id': 507, 'synset': 'goat.n.01', 'synonyms': ['goat'], 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'id': 508, 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'id': 509, 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'r', 'id': 510, 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'id': 511, 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'id': 512, 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'id': 513, 'synset': 'goose.n.01', 'synonyms': ['goose'], 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'id': 514, 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'id': 515, 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'def': 'any of numerous inedible fruits with hard rinds', 'name': 'gourd'}, {'frequency': 'r', 'id': 516, 'synset': 'gown.n.04', 'synonyms': ['surgical_gown', 
'scrubs_(surgical_clothing)'], 'def': 'protective garment worn by surgeons during operations', 'name': 'surgical_gown'}, {'frequency': 'f', 'id': 517, 'synset': 'grape.n.01', 'synonyms': ['grape'], 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'r', 'id': 518, 'synset': 'grasshopper.n.01', 'synonyms': ['grasshopper'], 'def': 'plant-eating insect with hind legs adapted for leaping', 'name': 'grasshopper'}, {'frequency': 'c', 'id': 519, 'synset': 'grater.n.01', 'synonyms': ['grater'], 'def': 'utensil with sharp perforations for shredding foods (as vegetables or cheese)', 'name': 'grater'}, {'frequency': 'c', 'id': 520, 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'def': 'a stone that is used to mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'id': 521, 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'c', 'id': 522, 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'c', 'id': 523, 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'id': 524, 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'r', 'id': 525, 'synset': 'grillroom.n.01', 'synonyms': ['grillroom', 'grill_(restaurant)'], 'def': 'a restaurant where food is cooked on a grill', 'name': 'grillroom'}, {'frequency': 'r', 'id': 526, 'synset': 'grinder.n.04', 'synonyms': ['grinder_(tool)'], 'def': 'a machine tool that polishes metal', 'name': 'grinder_(tool)'}, {'frequency': 'r', 'id': 527, 'synset': 'grits.n.01', 
'synonyms': ['grits', 'hominy_grits'], 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'id': 528, 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'id': 529, 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'r', 'id': 530, 'synset': 'guacamole.n.01', 'synonyms': ['guacamole'], 'def': 'a dip made of mashed avocado mixed with chopped onions and other seasonings', 'name': 'guacamole'}, {'frequency': 'f', 'id': 531, 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'id': 532, 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'def': 'mostly white aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'id': 533, 'synset': 'gun.n.01', 'synonyms': ['gun'], 'def': 'a weapon that discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'r', 'id': 534, 'synset': 'hair_spray.n.01', 'synonyms': ['hair_spray'], 'def': 'substance sprayed on the hair to hold it in place', 'name': 'hair_spray'}, {'frequency': 'c', 'id': 535, 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'id': 536, 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'id': 537, 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, {'frequency': 'f', 'id': 538, 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'def': 'meat cut from the 
thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'id': 539, 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'id': 540, 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'r', 'id': 541, 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'id': 542, 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'r', 'id': 543, 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 'hamster'}, {'frequency': 'c', 'id': 544, 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'def': 'a hand-held electric blower that can blow warm air onto the hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'id': 545, 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'id': 546, 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'id': 547, 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'id': 548, 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 'id': 549, 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'def': 'a square piece of cloth used for 
wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'id': 550, 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'id': 551, 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'id': 552, 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'id': 553, 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'id': 554, 'synset': 'hat.n.01', 'synonyms': ['hat'], 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 'id': 555, 'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'r', 'id': 556, 'synset': 'hatch.n.03', 'synonyms': ['hatch'], 'def': 'a movable barrier covering a hatchway', 'name': 'hatch'}, {'frequency': 'c', 'id': 557, 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'def': 'a garment that covers the head and face', 'name': 'veil'}, {'frequency': 'f', 'id': 558, 'synset': 'headband.n.01', 'synonyms': ['headband'], 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'id': 559, 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'id': 560, 'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'def': 'a powerful light with reflector; attached to 
the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'id': 561, 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'id': 562, 'synset': 'headset.n.01', 'synonyms': ['headset'], 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'id': 563, 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'r', 'id': 564, 'synset': 'hearing_aid.n.02', 'synonyms': ['hearing_aid'], 'def': 'an acoustic device used to direct sound to the ear of a hearing-impaired person', 'name': 'hearing_aid'}, {'frequency': 'c', 'id': 565, 'synset': 'heart.n.02', 'synonyms': ['heart'], 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'id': 566, 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'id': 567, 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'def': 'an aircraft without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'id': 568, 'synset': 'helmet.n.02', 'synonyms': ['helmet'], 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'id': 569, 'synset': 'heron.n.02', 'synonyms': ['heron'], 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'id': 570, 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'def': 'a chair for feeding a very young child', 'name': 'highchair'}, {'frequency': 'f', 'id': 571, 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'def': 'a joint 
that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'id': 572, 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'id': 573, 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'id': 574, 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'id': 575, 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'id': 576, 'synset': 'honey.n.01', 'synonyms': ['honey'], 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'id': 577, 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'id': 578, 'synset': 'hook.n.05', 'synonyms': ['hook'], 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'f', 'id': 579, 'synset': 'horse.n.01', 'synonyms': ['horse'], 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'id': 580, 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'id': 581, 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 'hot-air_balloon'}, {'frequency': 'r', 'id': 582, 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'def': 'a portable 
electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'id': 583, 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'id': 584, 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'id': 585, 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'r', 'id': 586, 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'id': 587, 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 'humous'], 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'c', 'id': 588, 'synset': 'ice_bear.n.01', 'synonyms': ['polar_bear'], 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'id': 589, 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'id': 590, 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'id': 591, 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'id': 592, 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 'id': 593, 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'def': 'skate consisting of a boot with a steel blade 
fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'r', 'id': 594, 'synset': 'ice_tea.n.01', 'synonyms': ['ice_tea', 'iced_tea'], 'def': 'strong tea served over ice', 'name': 'ice_tea'}, {'frequency': 'c', 'id': 595, 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'id': 596, 'synset': 'incense.n.01', 'synonyms': ['incense'], 'def': 'a substance that produces a fragrant odor when burned', 'name': 'incense'}, {'frequency': 'r', 'id': 597, 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'c', 'id': 598, 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'id': 599, 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'r', 'id': 600, 'synset': 'ironing_board.n.01', 'synonyms': ['ironing_board'], 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 'ironing_board'}, {'frequency': 'f', 'id': 601, 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'r', 'id': 602, 'synset': 'jam.n.01', 'synonyms': ['jam'], 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'id': 603, 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'id': 604, 'synset': 'jeep.n.01', 'synonyms': ['jeep', 'landrover'], 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'id': 605, 
'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'id': 606, 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'id': 607, 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'c', 'id': 608, 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'id': 609, 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'def': 'a control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'r', 'id': 610, 'synset': 'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'id': 611, 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'id': 612, 'synset': 'keg.n.02', 'synonyms': ['keg'], 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'id': 613, 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'id': 614, 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'id': 615, 'synset': 'key.n.01', 'synonyms': ['key'], 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, {'frequency': 'r', 'id': 616, 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'def': 'a plastic card used to gain access 
typically to a door', 'name': 'keycard'}, {'frequency': 'r', 'id': 617, 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'id': 618, 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'id': 619, 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'c', 'id': 620, 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'id': 621, 'synset': 'kite.n.03', 'synonyms': ['kite'], 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 'id': 622, 'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'id': 623, 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'id': 624, 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'id': 625, 'synset': 'knife.n.01', 'synonyms': ['knife'], 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'id': 626, 'synset': 'knight.n.02', 'synonyms': ['knight_(chess_piece)', 'horse_(chess_piece)'], 'def': 'a chess game piece shaped to resemble the head of a horse', 'name': 'knight_(chess_piece)'}, {'frequency': 'r', 'id': 627, 'synset': 'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'def': 'needle consisting of a slender rod with pointed ends; usually used in 
pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'id': 628, 'synset': 'knob.n.02', 'synonyms': ['knob'], 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'id': 629, 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'id': 630, 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'id': 631, 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'id': 632, 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'id': 633, 'synset': 'ladle.n.01', 'synonyms': ['ladle'], 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 'r', 'id': 634, 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'c', 'id': 635, 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'id': 636, 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'id': 637, 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'id': 638, 'synset': 'lamppost.n.01', 'synonyms': ['lamppost'], 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, 
{'frequency': 'f', 'id': 639, 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'id': 640, 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'id': 641, 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'id': 642, 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'id': 643, 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'c', 'id': 644, 'synset': 'latch.n.02', 'synonyms': ['latch'], 'def': 'a bar that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'id': 645, 'synset': 'lawn_mower.n.01', 'synonyms': ['lawn_mower'], 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'id': 646, 'synset': 'leather.n.01', 'synonyms': ['leather'], 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'id': 647, 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'id': 648, 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'def': "a child's plastic construction set for making models from blocks", 'name': 'Lego'}, {'frequency': 'f', 'id': 649, 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'def': 'yellow oval fruit with juicy acidic 
flesh', 'name': 'lemon'}, {'frequency': 'r', 'id': 650, 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'id': 651, 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'id': 652, 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'id': 653, 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'id': 654, 'synset': 'life_jacket.n.01', 'synonyms': ['life_jacket', 'life_vest'], 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 'life_jacket'}, {'frequency': 'f', 'id': 655, 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'def': 'glass bulb or tube shaped electric device that emits light (DO NOT MARK LAMPS AS A WHOLE)', 'name': 'lightbulb'}, {'frequency': 'r', 'id': 656, 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'c', 'id': 657, 'synset': 'lime.n.06', 'synonyms': ['lime'], 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'id': 658, 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'r', 'id': 659, 'synset': 'linen.n.02', 'synonyms': ['linen_paper'], 'def': 'a high-quality paper made of linen fibers or with a linen finish', 'name': 'linen_paper'}, {'frequency': 'c', 'id': 
660, 'synset': 'lion.n.01', 'synonyms': ['lion'], 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'id': 661, 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'c', 'id': 662, 'synset': 'lipstick.n.01', 'synonyms': ['lipstick', 'lip_rouge'], 'def': 'makeup that is used to color the lips', 'name': 'lipstick'}, {'frequency': 'r', 'id': 663, 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'def': 'an alcoholic beverage that is distilled rather than fermented', 'name': 'liquor'}, {'frequency': 'r', 'id': 664, 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'r', 'id': 665, 'synset': 'loafer.n.02', 'synonyms': ['Loafer_(type_of_shoe)'], 'def': 'a low leather step-in shoe', 'name': 'Loafer_(type_of_shoe)'}, {'frequency': 'f', 'id': 666, 'synset': 'log.n.01', 'synonyms': ['log'], 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'id': 667, 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'c', 'id': 668, 'synset': 'lotion.n.01', 'synonyms': ['lotion'], 'def': 'any of various cosmetic preparations that are applied to the skin', 'name': 'lotion'}, {'frequency': 'f', 'id': 669, 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 'speaker_(stero_equipment)'}, {'frequency': 'c', 'id': 670, 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'def': 'small sofa that seats two people', 'name': 'loveseat'}, {'frequency': 'r', 'id': 671, 'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 
'id': 672, 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'id': 673, 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'r', 'id': 674, 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'c', 'id': 675, 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'id': 676, 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'id': 677, 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'c', 'id': 678, 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'id': 679, 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'id': 680, 'synset': 'manhole.n.01', 'synonyms': ['manhole'], 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'c', 'id': 681, 'synset': 'map.n.01', 'synonyms': ['map'], 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'c', 'id': 682, 'synset': 'marker.n.03', 'synonyms': ['marker'], 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'id': 683, 'synset': 
'martini.n.01', 'synonyms': ['martini'], 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'id': 684, 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'id': 685, 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'id': 686, 'synset': 'masher.n.02', 'synonyms': ['masher'], 'def': 'a kitchen utensil used for mashing (e.g. potatoes)', 'name': 'masher'}, {'frequency': 'f', 'id': 687, 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'id': 688, 'synset': 'mast.n.01', 'synonyms': ['mast'], 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'id': 689, 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'id': 690, 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'id': 691, 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'id': 692, 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'id': 693, 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'def': 'measuring instrument having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'id': 694, 'synset': 'meatball.n.01', 
'synonyms': ['meatball'], 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'id': 695, 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'r', 'id': 696, 'synset': 'melon.n.01', 'synonyms': ['melon'], 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'id': 697, 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'def': 'device for converting sound waves into electrical energy', 'name': 'microphone'}, {'frequency': 'r', 'id': 698, 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'id': 699, 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'id': 700, 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'c', 'id': 701, 'synset': 'milk.n.01', 'synonyms': ['milk'], 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'f', 'id': 702, 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'id': 703, 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'id': 704, 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'id': 705, 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'def': 'glove that encases the thumb separately and the other four fingers 
together', 'name': 'mitten'}, {'frequency': 'c', 'id': 706, 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'id': 707, 'synset': 'money.n.03', 'synonyms': ['money'], 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'id': 708, 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'id': 709, 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'id': 710, 'synset': 'motor.n.01', 'synonyms': ['motor'], 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'id': 711, 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'id': 712, 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'r', 'id': 713, 'synset': 'motorboat.n.01', 'synonyms': ['motorboat', 'powerboat'], 'def': 'a boat propelled by an internal-combustion engine', 'name': 'motorboat'}, {'frequency': 'f', 'id': 714, 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'id': 715, 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'def': '(baseball) the slight elevation on which the pitcher stands', 'name': 'mound_(baseball)'}, {'frequency': 'r', 'id': 716, 'synset': 'mouse.n.01', 'synonyms': 
['mouse_(animal_rodent)'], 'def': 'a small rodent with pointed snouts and small ears on elongated bodies with slender usually hairless tails', 'name': 'mouse_(animal_rodent)'}, {'frequency': 'f', 'id': 717, 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'def': 'a computer input device that controls an on-screen pointer', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'id': 718, 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'id': 719, 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'def': 'a sweet quick bread baked in a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'id': 720, 'synset': 'mug.n.04', 'synonyms': ['mug'], 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'id': 721, 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'id': 722, 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'r', 'id': 723, 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'id': 724, 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'r', 'id': 725, 'synset': 'nameplate.n.01', 'synonyms': ['nameplate'], 'def': 'a plate bearing a name', 'name': 'nameplate'}, {'frequency': 'f', 'id': 726, 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'def': 'a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 
'name': 'napkin'}, {'frequency': 'r', 'id': 727, 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'id': 728, 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'id': 729, 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'r', 'id': 730, 'synset': 'needle.n.03', 'synonyms': ['needle'], 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'id': 731, 'synset': 'nest.n.01', 'synonyms': ['nest'], 'def': 'a structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'r', 'id': 732, 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'id': 733, 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'id': 734, 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'r', 'id': 735, 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'id': 736, 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 
'id': 737, 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'c', 'id': 738, 'synset': 'nut.n.03', 'synonyms': ['nut'], 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'id': 739, 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'c', 'id': 740, 'synset': 'oar.n.01', 'synonyms': ['oar'], 'def': 'an implement used to propel or steer a boat', 'name': 'oar'}, {'frequency': 'r', 'id': 741, 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'id': 742, 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'id': 743, 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'id': 744, 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'id': 745, 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'def': 'beaten eggs cooked until just set; may be folded around e.g. 
ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'id': 746, 'synset': 'onion.n.01', 'synonyms': ['onion'], 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'id': 747, 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'id': 748, 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'r', 'id': 749, 'synset': 'oregano.n.01', 'synonyms': ['oregano', 'marjoram'], 'def': 'aromatic Eurasian perennial herb used in cooking and baking', 'name': 'oregano'}, {'frequency': 'c', 'id': 750, 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'c', 'id': 751, 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'def': 'thick cushion used as a seat', 'name': 'ottoman'}, {'frequency': 'c', 'id': 752, 'synset': 'overall.n.01', 'synonyms': ['overalls_(clothing)'], 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'id': 753, 'synset': 'owl.n.01', 'synonyms': ['owl'], 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'id': 754, 'synset': 'packet.n.03', 'synonyms': ['packet'], 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'id': 755, 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'id': 756, 'synset': 'pad.n.04', 'synonyms': ['pad'], 'def': 'a flat mass of soft material used for protection, stuffing, or comfort', 'name': 'pad'}, {'frequency': 'c', 'id': 
757, 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'id': 758, 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'r', 'id': 759, 'synset': 'paintbox.n.01', 'synonyms': ['paintbox'], 'def': "a box containing a collection of cubes or tubes of artists' paint", 'name': 'paintbox'}, {'frequency': 'c', 'id': 760, 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'id': 761, 'synset': 'painting.n.01', 'synonyms': ['painting'], 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'c', 'id': 762, 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'id': 763, 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'id': 764, 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 'cooking_pan'], 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'id': 765, 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'id': 766, 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'id': 767, 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'id': 768, 'synset': 
'papaya.n.02', 'synonyms': ['papaya'], 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'r', 'id': 769, 'synset': 'paper_clip.n.01', 'synonyms': ['paperclip'], 'def': 'a wire or plastic clip for holding sheets of paper together', 'name': 'paperclip'}, {'frequency': 'f', 'id': 770, 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'id': 771, 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'id': 772, 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'id': 773, 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'id': 774, 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'r', 'id': 775, 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'id': 776, 'synset': 'parasail.n.01', 'synonyms': ['parasail_(sports)'], 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'r', 'id': 777, 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'r', 'id': 778, 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 
'f', 'id': 779, 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'id': 780, 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'id': 781, 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'id': 782, 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'r', 'id': 783, 'synset': 'passport.n.02', 'synonyms': ['passport'], 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'id': 784, 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'id': 785, 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'id': 786, 'synset': 'pea.n.01', 'synonyms': ['pea_(food)'], 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'id': 787, 'synset': 'peach.n.03', 'synonyms': ['peach'], 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'id': 788, 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'c', 'id': 789, 'synset': 'pear.n.01', 'synonyms': ['pear'], 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'r', 'id': 790, 'synset': 'peeler.n.03', 'synonyms': 
['peeler_(tool_for_fruit_and_vegetables)'], 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'id': 791, 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'id': 792, 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'id': 793, 'synset': 'pen.n.01', 'synonyms': ['pen'], 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'c', 'id': 794, 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'id': 795, 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'id': 796, 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'def': 'a rotary implement for sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'id': 797, 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'id': 798, 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'id': 799, 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'id': 800, 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, 
{'frequency': 'c', 'id': 801, 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'id': 802, 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'id': 803, 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'id': 804, 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'id': 805, 'synset': 'person.n.01', 'synonyms': ['baby', 'child', 'boy', 'girl', 'man', 'woman', 'person', 'human'], 'def': 'a human being', 'name': 'baby'}, {'frequency': 'r', 'id': 806, 'synset': 'pet.n.01', 'synonyms': ['pet'], 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'r', 'id': 807, 'synset': 'petfood.n.01', 'synonyms': ['petfood', 'pet-food'], 'def': 'food prepared for animal pets', 'name': 'petfood'}, {'frequency': 'r', 'id': 808, 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'def': 'long bench with backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'id': 809, 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'id': 810, 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'c', 'id': 811, 'synset': 'piano.n.01', 'synonyms': ['piano'], 
'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'id': 812, 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'id': 813, 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'id': 814, 'synset': 'pie.n.01', 'synonyms': ['pie'], 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'id': 815, 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'id': 816, 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'id': 817, 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'id': 818, 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'def': 'a small slender (often pointed) piece of wood or metal used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'id': 819, 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'id': 820, 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'id': 821, 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'id': 822, 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'def': 'a toy 
consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'id': 823, 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'id': 824, 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'id': 825, 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'r', 'id': 826, 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'id': 827, 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'id': 828, 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 'id': 829, 'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. 
tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'id': 830, 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'id': 831, 'synset': 'plate.n.04', 'synonyms': ['plate'], 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'id': 832, 'synset': 'platter.n.01', 'synonyms': ['platter'], 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'id': 833, 'synset': 'playing_card.n.01', 'synonyms': ['playing_card'], 'def': 'one of a pack of cards that are used to play card games', 'name': 'playing_card'}, {'frequency': 'r', 'id': 834, 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'id': 835, 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'id': 836, 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'id': 837, 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'id': 838, 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'id': 839, 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'def': 'fire iron consisting of a metal rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'id': 840, 'synset': 
'pole.n.01', 'synonyms': ['pole', 'post'], 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'r', 'id': 841, 'synset': 'police_van.n.01', 'synonyms': ['police_van', 'police_wagon', 'paddy_wagon', 'patrol_wagon'], 'def': 'van used by police to transport prisoners', 'name': 'police_van'}, {'frequency': 'f', 'id': 842, 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'id': 843, 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'id': 844, 'synset': 'pony.n.05', 'synonyms': ['pony'], 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'id': 845, 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'id': 846, 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'r', 'id': 847, 'synset': 'portrait.n.02', 'synonyms': ['portrait', 'portrayal'], 'def': 'any likeness of a person, in any medium', 'name': 'portrait'}, {'frequency': 'c', 'id': 848, 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'id': 849, 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'id': 850, 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'def': 'a sign posted in a public 
place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'id': 851, 'synset': 'pot.n.01', 'synonyms': ['pot'], 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 'f', 'id': 852, 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'id': 853, 'synset': 'potato.n.01', 'synonyms': ['potato'], 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'id': 854, 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'id': 855, 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'id': 856, 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'r', 'id': 857, 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'id': 858, 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'f', 'id': 859, 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'id': 860, 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'id': 861, 'synset': 'projector.n.02', 'synonyms': ['projector'], 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'id': 862, 'synset': 'propeller.n.01', 'synonyms': ['propeller', 
'propellor'], 'def': 'a mechanical device that rotates to push against air or water', 'name': 'propeller'}, {'frequency': 'r', 'id': 863, 'synset': 'prune.n.01', 'synonyms': ['prune'], 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'id': 864, 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'id': 865, 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'id': 866, 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'id': 867, 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'id': 868, 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'id': 869, 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'id': 870, 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'r', 'id': 871, 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'id': 872, 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'id': 873, 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'def': 'a tart filled with rich unsweetened 
custard; often contains other ingredients (as cheese or ham or seafood or vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'id': 874, 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'id': 875, 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'id': 876, 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'id': 877, 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'id': 878, 'synset': 'radar.n.01', 'synonyms': ['radar'], 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'c', 'id': 879, 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'id': 880, 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'id': 881, 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'id': 882, 'synset': 'raft.n.01', 'synonyms': ['raft'], 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'id': 883, 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'def': 'a 
cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, {'frequency': 'c', 'id': 884, 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'id': 885, 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'id': 886, 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'id': 887, 'synset': 'rat.n.01', 'synonyms': ['rat'], 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'id': 888, 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'id': 889, 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'id': 890, 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'def': 'car mirror that reflects the view out of the rear window', 'name': 'rearview_mirror'}, {'frequency': 'c', 'id': 891, 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'id': 892, 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'r', 'id': 893, 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified 
acoustically or electronically', 'name': 'record_player'}, {'frequency': 'r', 'id': 894, 'synset': 'red_cabbage.n.02', 'synonyms': ['red_cabbage'], 'def': 'compact head of purplish-red leaves', 'name': 'red_cabbage'}, {'frequency': 'f', 'id': 895, 'synset': 'reflector.n.01', 'synonyms': ['reflector'], 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'id': 896, 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'id': 897, 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'id': 898, 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'r', 'id': 899, 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'id': 900, 'synset': 'ring.n.08', 'synonyms': ['ring'], 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'id': 901, 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'id': 902, 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'id': 903, 'synset': 'robe.n.01', 'synonyms': ['robe'], 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'id': 904, 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'id': 905, 'synset': 
'roller_skate.n.01', 'synonyms': ['roller_skate'], 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'id': 906, 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'id': 907, 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'id': 908, 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'id': 909, 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'id': 910, 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'id': 911, 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'id': 912, 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'id': 913, 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'id': 914, 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'id': 915, 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'def': 'a large bag (or pair of 
bags) hung over a saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'id': 916, 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'c', 'id': 917, 'synset': 'sail.n.01', 'synonyms': ['sail'], 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'c', 'id': 918, 'synset': 'salad.n.01', 'synonyms': ['salad'], 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'id': 919, 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'r', 'id': 920, 'synset': 'salami.n.01', 'synonyms': ['salami'], 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'r', 'id': 921, 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'id': 922, 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'r', 'id': 923, 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'id': 924, 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'id': 925, 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'id': 926, 
'synset': 'sandwich.n.01', 'synonyms': ['sandwich'], 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'id': 927, 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'id': 928, 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'id': 929, 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'id': 930, 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'id': 931, 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'id': 932, 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'id': 933, 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'id': 934, 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'id': 935, 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'id': 936, 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'id': 937, 'synset': 'scissors.n.01', 'synonyms': 
['scissors'], 'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'c', 'id': 938, 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'c', 'id': 939, 'synset': 'scrambled_eggs.n.01', 'synonyms': ['scrambled_eggs'], 'def': 'eggs beaten and cooked to a soft firm consistency while stirring', 'name': 'scrambled_eggs'}, {'frequency': 'r', 'id': 940, 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'r', 'id': 941, 'synset': 'scratcher.n.03', 'synonyms': ['scratcher'], 'def': 'a device used for scratching', 'name': 'scratcher'}, {'frequency': 'c', 'id': 942, 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'c', 'id': 943, 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'id': 944, 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'r', 'id': 945, 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'r', 'id': 946, 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'id': 947, 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'id': 948, 'synset': 
'seashell.n.01', 'synonyms': ['seashell'], 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'r', 'id': 949, 'synset': 'seedling.n.01', 'synonyms': ['seedling'], 'def': 'young plant or tree grown from a seed', 'name': 'seedling'}, {'frequency': 'c', 'id': 950, 'synset': 'serving_dish.n.01', 'synonyms': ['serving_dish'], 'def': 'a dish used for serving food', 'name': 'serving_dish'}, {'frequency': 'r', 'id': 951, 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'r', 'id': 952, 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'id': 953, 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'r', 'id': 954, 'synset': 'shark.n.01', 'synonyms': ['shark'], 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'id': 955, 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'id': 956, 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'id': 957, 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'id': 958, 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'id': 959, 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'def': 'cloak consisting of an 
oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'id': 960, 'synset': 'shears.n.01', 'synonyms': ['shears'], 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'id': 961, 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'id': 962, 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'id': 963, 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 'sherbet'], 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'r', 'id': 964, 'synset': 'shield.n.02', 'synonyms': ['shield'], 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'id': 965, 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'id': 966, 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'c', 'id': 967, 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'id': 968, 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'id': 969, 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'id': 970, 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'def': 'a small glass 
adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'c', 'id': 971, 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'id': 972, 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'id': 973, 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'f', 'id': 974, 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'id': 975, 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'r', 'id': 976, 'synset': 'sieve.n.01', 'synonyms': ['sieve', 'screen_(sieve)'], 'def': 'a strainer for separating lumps from powdered material or grading particles', 'name': 'sieve'}, {'frequency': 'f', 'id': 977, 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'id': 978, 'synset': 'silo.n.01', 'synonyms': ['silo'], 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'id': 979, 'synset': 'sink.n.01', 'synonyms': ['sink'], 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'id': 980, 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'id': 981, 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'def': 'a long 
pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'id': 982, 'synset': 'ski.n.01', 'synonyms': ['ski'], 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'id': 983, 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'id': 984, 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'id': 985, 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'def': 'a pole with metal points used as an aid in skiing', 'name': 'ski_pole'}, {'frequency': 'f', 'id': 986, 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'c', 'id': 987, 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'id': 988, 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'id': 989, 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'id': 990, 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'id': 991, 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'id': 992, 'synset': 'snake.n.01', 'synonyms': 
['snake', 'serpent'], 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'id': 993, 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'id': 994, 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'id': 995, 'synset': 'snowmobile.n.01', 'synonyms': ['snowmobile'], 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'id': 996, 'synset': 'soap.n.01', 'synonyms': ['soap'], 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'id': 997, 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'id': 998, 'synset': 'sock.n.01', 'synonyms': ['sock'], 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'r', 'id': 999, 'synset': 'soda_fountain.n.02', 'synonyms': ['soda_fountain'], 'def': 'an apparatus for dispensing soda water', 'name': 'soda_fountain'}, {'frequency': 'r', 'id': 1000, 'synset': 'soda_water.n.01', 'synonyms': ['carbonated_water', 'club_soda', 'seltzer', 'sparkling_water'], 'def': 'effervescent beverage artificially charged with carbon dioxide', 'name': 'carbonated_water'}, {'frequency': 'f', 'id': 1001, 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'id': 1002, 'synset': 'softball.n.01', 'synonyms': ['softball'], 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'id': 1003, 'synset': 
'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'id': 1004, 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'c', 'id': 1005, 'synset': 'soup.n.01', 'synonyms': ['soup'], 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'id': 1006, 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'id': 1007, 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'id': 1008, 'synset': 'sour_cream.n.01', 'synonyms': ['sour_cream', 'soured_cream'], 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'id': 1009, 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'id': 1010, 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'id': 1011, 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'id': 1012, 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'id': 1013, 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'def': 
'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'id': 1014, 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'id': 1015, 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'r', 'id': 1016, 'synset': 'spider.n.01', 'synonyms': ['spider'], 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'c', 'id': 1017, 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'id': 1018, 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'id': 1019, 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'id': 1020, 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'id': 1021, 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'c', 'id': 1022, 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'r', 'id': 1023, 'synset': 'starfish.n.01', 'synonyms': 
['starfish', 'sea_star'], 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'id': 1024, 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'id': 1025, 'synset': 'steak.n.01', 'synonyms': ['steak_(food)'], 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'id': 1026, 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'r', 'id': 1027, 'synset': 'steamer.n.02', 'synonyms': ['steamer_(kitchen_appliance)'], 'def': 'a cooking utensil that can be used to cook food by steaming it', 'name': 'steamer_(kitchen_appliance)'}, {'frequency': 'f', 'id': 1028, 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'id': 1029, 'synset': 'stencil.n.01', 'synonyms': ['stencil'], 'def': 'a sheet of material (metal, plastic, etc.) 
that has been perforated with a pattern; ink or paint can pass through the perforations to create the printed pattern on the surface below', 'name': 'stencil'}, {'frequency': 'r', 'id': 1030, 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'id': 1031, 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'id': 1032, 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'id': 1033, 'synset': 'stew.n.02', 'synonyms': ['stew'], 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'id': 1034, 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'id': 1035, 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'c', 'id': 1036, 'synset': 'stocking.n.01', 'synonyms': ['stockings_(leg_wear)'], 'def': 'close-fitting hosiery to cover the foot and leg; come in matched pairs', 'name': 'stockings_(leg_wear)'}, {'frequency': 'f', 'id': 1037, 'synset': 'stool.n.01', 'synonyms': ['stool'], 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'id': 1038, 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'id': 1039, 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'def': 'a red light on the rear of a motor vehicle that signals when the brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'id': 1040, 'synset': 'stove.n.01', 'synonyms': ['stove', 
'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'id': 1041, 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'id': 1042, 'synset': 'strap.n.01', 'synonyms': ['strap'], 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'id': 1043, 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'id': 1044, 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'id': 1045, 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'id': 1046, 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'id': 1047, 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'id': 1048, 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'def': 'a pointed tool for writing or drawing or engraving', 'name': 'stylus'}, {'frequency': 'r', 'id': 1049, 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'id': 1050, 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'def': 'a dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'id': 1051, 'synset': 'sugarcane.n.01', 'synonyms': 
['sugarcane_(plant)'], 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'c', 'id': 1052, 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'id': 1053, 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'id': 1054, 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'id': 1055, 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'r', 'id': 1056, 'synset': 'sunscreen.n.01', 'synonyms': ['sunscreen', 'sunblock'], 'def': 'a cream spread on the skin; contains a chemical to filter out ultraviolet light and so protect from sunburn', 'name': 'sunscreen'}, {'frequency': 'f', 'id': 1057, 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'id': 1058, 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'id': 1059, 'synset': 'swab.n.02', 'synonyms': ['mop'], 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'id': 1060, 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'id': 1061, 'synset': 
'sweatband.n.02', 'synonyms': ['sweatband'], 'def': 'a band of material tied around the forehead or wrist to absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'id': 1062, 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'id': 1063, 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'id': 1064, 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'id': 1065, 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'id': 1066, 'synset': 'sword.n.01', 'synonyms': ['sword'], 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'id': 1067, 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'id': 1068, 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'id': 1069, 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'id': 1070, 'synset': 'table.n.02', 'synonyms': ['table'], 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'id': 1071, 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'def': 'a lamp that sits on a table', 
'name': 'table_lamp'}, {'frequency': 'f', 'id': 1072, 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'id': 1073, 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'id': 1074, 'synset': 'taco.n.02', 'synonyms': ['taco'], 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'id': 1075, 'synset': 'tag.n.02', 'synonyms': ['tag'], 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'id': 1076, 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'id': 1077, 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'id': 1078, 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'c', 'id': 1079, 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'id': 1080, 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'c', 'id': 1081, 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 
'c', 'id': 1082, 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'def': 'measuring instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'id': 1083, 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'id': 1084, 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'id': 1085, 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'id': 1086, 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'r', 'id': 1087, 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'id': 1088, 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'id': 1089, 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'c', 'id': 1090, 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'id': 1091, 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'id': 1092, 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'def': 'electronic device for communicating by voice over long distances', 'name': 'telephone'}, {'frequency': 'c', 'id': 1093, 'synset': 
'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 'telephone_box', 'telephone_kiosk'], 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'id': 1094, 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'id': 1095, 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'id': 1096, 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'id': 1097, 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'id': 1098, 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'id': 1099, 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'id': 1100, 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'id': 1101, 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'id': 1102, 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'c', 'id': 1103, 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'def': 'a regulator 
for automatically regulating temperature by starting or stopping the supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'id': 1104, 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'id': 1105, 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'id': 1106, 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'id': 1107, 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'id': 1108, 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'id': 1109, 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'id': 1110, 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'id': 1111, 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'r', 'id': 1112, 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'id': 1113, 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'def': 'a soft thin (usually translucent) paper', 'name': 
'tissue_paper'}, {'frequency': 'c', 'id': 1114, 'synset': 'toast.n.01', 'synonyms': ['toast_(food)'], 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'id': 1115, 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'c', 'id': 1116, 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'id': 1117, 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'id': 1118, 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'id': 1119, 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'c', 'id': 1120, 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'id': 1121, 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'id': 1122, 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'id': 1123, 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'c', 'id': 1124, 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 
'name': 'toothpick'}, {'frequency': 'c', 'id': 1125, 'synset': 'top.n.09', 'synonyms': ['cover'], 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'id': 1126, 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'id': 1127, 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'id': 1128, 'synset': 'towel.n.01', 'synonyms': ['towel'], 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'id': 1129, 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'id': 1130, 'synset': 'toy.n.03', 'synonyms': ['toy'], 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'id': 1131, 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'id': 1132, 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'r', 'id': 1133, 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'c', 'id': 1134, 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, 
{'frequency': 'f', 'id': 1135, 'synset': 'train.n.01', 'synonyms': ['train_(railroad_vehicle)', 'railroad_train'], 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'id': 1136, 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'id': 1137, 'synset': 'tray.n.01', 'synonyms': ['tray'], 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'id': 1138, 'synset': 'tree_house.n.01', 'synonyms': ['tree_house'], 'def': '(NOT A TREE) a PLAYHOUSE built in the branches of a tree', 'name': 'tree_house'}, {'frequency': 'r', 'id': 1139, 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'id': 1140, 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'r', 'id': 1141, 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'c', 'id': 1142, 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'id': 1143, 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'id': 1144, 'synset': 'truck.n.01', 'synonyms': ['truck'], 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'id': 1145, 'synset': 
'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'id': 1146, 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'def': 'luggage consisting of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'id': 1147, 'synset': 'tub.n.02', 'synonyms': ['vat'], 'def': 'a large open vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'id': 1148, 'synset': 'turban.n.01', 'synonyms': ['turban'], 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'r', 'id': 1149, 'synset': 'turkey.n.01', 'synonyms': ['turkey_(bird)'], 'def': 'large gallinaceous bird with fan-shaped tail; widely domesticated for food', 'name': 'turkey_(bird)'}, {'frequency': 'c', 'id': 1150, 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'id': 1151, 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'id': 1152, 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'r', 'id': 1153, 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'r', 'id': 1154, 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'id': 1155, 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, 
{'frequency': 'c', 'id': 1156, 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'id': 1157, 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'c', 'id': 1158, 'synset': 'urinal.n.01', 'synonyms': ['urinal'], 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'r', 'id': 1159, 'synset': 'urn.n.01', 'synonyms': ['urn'], 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'id': 1160, 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'c', 'id': 1161, 'synset': 'valve.n.03', 'synonyms': ['valve'], 'def': 'control consisting of a mechanical device for controlling the flow of a fluid', 'name': 'valve'}, {'frequency': 'f', 'id': 1162, 'synset': 'vase.n.01', 'synonyms': ['vase'], 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'id': 1163, 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'id': 1164, 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'c', 'id': 1165, 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'id': 1166, 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, 
{'frequency': 'r', 'id': 1167, 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'id': 1168, 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'r', 'id': 1169, 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'def': 'an inflated ball used in playing volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'id': 1170, 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'id': 1171, 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'id': 1172, 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'id': 1173, 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'id': 1174, 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'id': 1175, 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'id': 1176, 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'id': 1177, 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, 
{'frequency': 'c', 'id': 1178, 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'id': 1179, 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'id': 1180, 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'def': 'a tall piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'id': 1181, 'synset': 'wasabi.n.02', 'synonyms': ['wasabi'], 'def': 'the thick green root of the wasabi plant that the Japanese use in cooking and that tastes like strong horseradish', 'name': 'wasabi'}, {'frequency': 'c', 'id': 1182, 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'id': 1183, 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'id': 1184, 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'id': 1185, 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'id': 1186, 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'id': 1187, 'synset': 'water_filter.n.01', 'synonyms': ['water_filter'], 'def': 'a filter to remove impurities from the water supply', 'name': 'water_filter'}, {'frequency': 'r', 'id': 1188, 'synset': 'water_heater.n.01', 'synonyms': 
['water_heater', 'hot-water_heater'], 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'r', 'id': 1189, 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'id': 1190, 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'id': 1191, 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'def': 'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'id': 1192, 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'id': 1193, 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'id': 1194, 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'c', 'id': 1195, 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'id': 1196, 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'id': 1197, 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'id': 1198, 'synset': 'wedding_cake.n.01', 
'synonyms': ['wedding_cake', 'bridecake'], 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'id': 1199, 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'id': 1200, 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 'wet_suit'}, {'frequency': 'f', 'id': 1201, 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'id': 1202, 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'id': 1203, 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'r', 'id': 1204, 'synset': 'whiskey.n.01', 'synonyms': ['whiskey'], 'def': 'a liquor made from fermented mash of grain', 'name': 'whiskey'}, {'frequency': 'r', 'id': 1205, 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'r', 'id': 1206, 'synset': 'wick.n.02', 'synonyms': ['wick'], 'def': 'a loosely woven cord in a candle or oil lamp that is lit on fire', 'name': 'wick'}, {'frequency': 'c', 'id': 1207, 'synset': 'wig.n.01', 'synonyms': ['wig'], 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'id': 1208, 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can 
cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'id': 1209, 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'def': 'a mill that is powered by the wind', 'name': 'windmill'}, {'frequency': 'c', 'id': 1210, 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'id': 1211, 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'id': 1212, 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'id': 1213, 'synset': 'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'r', 'id': 1214, 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'id': 1215, 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'r', 'id': 1216, 'synset': 'wing_chair.n.01', 'synonyms': ['wing_chair'], 'def': 'easy chair having wings on each side of a high back', 'name': 'wing_chair'}, {'frequency': 'c', 'id': 1217, 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'id': 1218, 'synset': 'wok.n.01', 'synonyms': ['wok'], 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'id': 1219, 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'def': 'a wild 
carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'id': 1220, 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'id': 1221, 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'id': 1222, 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'c', 'id': 1223, 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'id': 1224, 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'r', 'id': 1225, 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'def': 'an expensive vessel propelled by sail or power and used for cruising or racing', 'name': 'yacht'}, {'frequency': 'r', 'id': 1226, 'synset': 'yak.n.02', 'synonyms': ['yak'], 'def': 'large long-haired wild ox of Tibet often domesticated', 'name': 'yak'}, {'frequency': 'c', 'id': 1227, 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'r', 'id': 1228, 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'id': 1229, 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'id': 1230, 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa
+# fmt: on
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v1_categories.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v1_categories.py
new file mode 100644
index 0000000000000000000000000000000000000000..7374e6968bb006f5d8c49e75d9d3b31ea3d77d05
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v1_categories.py
@@ -0,0 +1,16 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Autogen with
+# with open("lvis_v1_val.json", "r") as f:
+# a = json.load(f)
+# c = a["categories"]
+# for x in c:
+# del x["image_count"]
+# del x["instance_count"]
+# LVIS_CATEGORIES = repr(c) + " # noqa"
+# with open("/tmp/lvis_categories.py", "wt") as f:
+# f.write(f"LVIS_CATEGORIES = {LVIS_CATEGORIES}")
+# Then paste the contents of that file below
+
+# fmt: off
+LVIS_CATEGORIES = [{'frequency': 'c', 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'id': 1, 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'id': 2, 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'id': 3, 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'f', 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'id': 4, 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'id': 5, 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'c', 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'id': 6, 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'synset': 'almond.n.02', 'synonyms': ['almond'], 'id': 7, 'def': 'oval-shaped edible seed of the almond tree', 'name': 'almond'}, {'frequency': 'c', 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'id': 8, 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'c', 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'id': 9, 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'synset': 'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'id': 10, 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 'transmitting_aerial'], 'id': 11, 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 
'synset': 'apple.n.01', 'synonyms': ['apple'], 'id': 12, 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'id': 13, 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'id': 14, 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'synset': 'apron.n.01', 'synonyms': ['apron'], 'id': 15, 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'id': 16, 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'r', 'synset': 'arctic.n.02', 'synonyms': ['arctic_(type_of_shoe)', 'galosh', 'golosh', 'rubber_(type_of_shoe)', 'gumshoe'], 'id': 17, 'def': 'a waterproof overshoe that protects shoes from water or snow', 'name': 'arctic_(type_of_shoe)'}, {'frequency': 'c', 'synset': 'armband.n.02', 'synonyms': ['armband'], 'id': 18, 'def': 'a band worn around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'id': 19, 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'id': 20, 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'id': 21, 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, {'frequency': 'c', 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'id': 22, 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'synset': 'ashcan.n.01', 'synonyms': ['trash_can', 
'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'id': 23, 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'id': 24, 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'id': 25, 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'id': 26, 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'f', 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'id': 27, 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'id': 28, 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'synset': 'awning.n.01', 'synonyms': ['awning'], 'id': 29, 'def': 'a canopy made of canvas to shelter people or things from rain or sun', 'name': 'awning'}, {'frequency': 'r', 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'id': 30, 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'r', 'synset': 'baboon.n.01', 'synonyms': ['baboon'], 'id': 31, 'def': 'large terrestrial monkeys having doglike muzzles', 'name': 'baboon'}, {'frequency': 'f', 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'id': 32, 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'synset': 'backboard.n.01', 'synonyms': ['basketball_backboard'], 'id': 33, 'def': 'a raised vertical board with basket attached; 
used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'synset': 'backpack.n.01', 'synonyms': ['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'id': 34, 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'id': 35, 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'id': 36, 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'id': 37, 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'id': 38, 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'id': 39, 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'id': 40, 'def': 'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'synset': 'ball.n.06', 'synonyms': ['ball'], 'id': 41, 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'id': 42, 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'id': 43, 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 'id': 44, 'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 
'f', 'synset': 'banana.n.02', 'synonyms': ['banana'], 'id': 45, 'def': 'elongated crescent-shaped yellow fruit with soft sweet flesh', 'name': 'banana'}, {'frequency': 'c', 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'id': 46, 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'id': 47, 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'f', 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'id': 48, 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'id': 49, 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'id': 50, 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'id': 51, 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 'name': 'barbell'}, {'frequency': 'r', 'synset': 'barge.n.01', 'synonyms': ['barge'], 'id': 52, 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'id': 53, 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'id': 54, 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'id': 55, 'def': 'a cart for carrying small loads; has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'synset': 
'base.n.03', 'synonyms': ['baseball_base'], 'id': 56, 'def': 'a place that the runner must touch before scoring', 'name': 'baseball_base'}, {'frequency': 'f', 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'id': 57, 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'id': 58, 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'id': 59, 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'id': 60, 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'id': 61, 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'id': 62, 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 'id': 63, 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'c', 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'id': 64, 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'id': 65, 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, {'frequency': 'f', 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'id': 66, 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'synset': 'bathrobe.n.01', 'synonyms': ['bathrobe'], 'id': 67, 'def': 'a loose-fitting robe of towelling; worn after a 
bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'id': 68, 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'id': 69, 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'synset': 'battery.n.02', 'synonyms': ['battery'], 'id': 70, 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'id': 71, 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'synset': 'bead.n.01', 'synonyms': ['bead'], 'id': 72, 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'c', 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'id': 73, 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'id': 74, 'def': 'a bag filled with dried beans or similar items; used in games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'id': 75, 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'synset': 'bear.n.01', 'synonyms': ['bear'], 'id': 76, 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'synset': 'bed.n.01', 'synonyms': ['bed'], 'id': 77, 'def': 'a piece of furniture that provides a place to sleep', 'name': 'bed'}, {'frequency': 'r', 'synset': 'bedpan.n.01', 'synonyms': ['bedpan'], 'id': 78, 'def': 'a shallow vessel used by a bedridden patient for defecation and urination', 'name': 'bedpan'}, {'frequency': 'f', 'synset': 
'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'id': 79, 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'synset': 'beef.n.01', 'synonyms': ['cow'], 'id': 80, 'def': 'cattle/cow', 'name': 'cow'}, {'frequency': 'f', 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'id': 81, 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'id': 82, 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'id': 83, 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'id': 84, 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'id': 85, 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'synset': 'bell.n.01', 'synonyms': ['bell'], 'id': 86, 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'synset': 'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'id': 87, 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'synset': 'belt.n.02', 'synonyms': ['belt'], 'id': 88, 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'synset': 'belt_buckle.n.01', 'synonyms': ['belt_buckle'], 'id': 89, 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'synset': 'bench.n.01', 'synonyms': ['bench'], 'id': 90, 'def': 'a long seat for more than one person', 'name': 'bench'}, {'frequency': 'c', 'synset': 'beret.n.01', 'synonyms': ['beret'], 'id': 91, 'def': 'a cap with no brim or 
bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'synset': 'bib.n.02', 'synonyms': ['bib'], 'id': 92, 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'id': 93, 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'id': 94, 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'id': 95, 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'f', 'synset': 'billboard.n.01', 'synonyms': ['billboard'], 'id': 96, 'def': 'large outdoor signboard', 'name': 'billboard'}, {'frequency': 'c', 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'id': 97, 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'id': 98, 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 'synset': 'bird.n.01', 'synonyms': ['bird'], 'id': 99, 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'c', 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'id': 100, 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'c', 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'id': 101, 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'id': 102, 'def': 'a cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'id': 103, 'def': 'a shelter for birds', 'name': 'birdhouse'}, 
{'frequency': 'f', 'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'id': 104, 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'id': 105, 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'id': 106, 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'id': 107, 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'synset': 'blackberry.n.01', 'synonyms': ['blackberry'], 'id': 108, 'def': 'large sweet black or very dark purple edible aggregate fruit', 'name': 'blackberry'}, {'frequency': 'f', 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'id': 109, 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'id': 110, 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'synset': 'blazer.n.01', 'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'id': 111, 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'id': 112, 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'id': 113, 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, {'frequency': 'f', 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'id': 114, 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, 
{'frequency': 'f', 'synset': 'blouse.n.01', 'synonyms': ['blouse'], 'id': 115, 'def': 'a top worn by women', 'name': 'blouse'}, {'frequency': 'f', 'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'id': 116, 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'id': 117, 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'id': 118, 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'r', 'synset': 'bob.n.05', 'synonyms': ['bob', 'bobber', 'bobfloat'], 'id': 119, 'def': 'a small float usually made of cork; attached to a fishing line', 'name': 'bob'}, {'frequency': 'c', 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'id': 120, 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'c', 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'id': 121, 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'synset': 'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'id': 122, 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'id': 123, 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'id': 124, 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'id': 125, 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'synset': 'bonnet.n.01', 'synonyms': ['bonnet'], 
'id': 126, 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'synset': 'book.n.01', 'synonyms': ['book'], 'id': 127, 'def': 'a written work or composition that has been published', 'name': 'book'}, {'frequency': 'c', 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'id': 128, 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'id': 129, 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'id': 130, 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'id': 131, 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'synset': 'boot.n.01', 'synonyms': ['boot'], 'id': 132, 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'id': 133, 'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'id': 134, 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'id': 135, 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'id': 136, 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 'f', 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'id': 137, 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, 
{'frequency': 'f', 'synset': 'bow_tie.n.01', 'synonyms': ['bow-tie', 'bowtie'], 'id': 138, 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'id': 139, 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'id': 140, 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'id': 141, 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'id': 142, 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'f', 'synset': 'box.n.01', 'synonyms': ['box'], 'id': 143, 'def': 'a (usually rectangular) container; may have a lid', 'name': 'box'}, {'frequency': 'r', 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'id': 144, 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 'boxing_glove'}, {'frequency': 'c', 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'id': 145, 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'id': 146, 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'id': 147, 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'synset': 'brassiere.n.01', 'synonyms': ['brassiere', 'bra', 'bandeau'], 'id': 148, 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'synset': 'bread-bin.n.01', 'synonyms': 
['bread-bin', 'breadbox'], 'id': 149, 'def': 'a container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'f', 'synset': 'bread.n.01', 'synonyms': ['bread'], 'id': 150, 'def': 'food made from dough of flour or meal and usually raised with yeast or baking powder and then baked', 'name': 'bread'}, {'frequency': 'r', 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'id': 151, 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'f', 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'id': 152, 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'id': 153, 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'f', 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'id': 154, 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'id': 155, 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'synset': 'broom.n.01', 'synonyms': ['broom'], 'id': 156, 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'id': 157, 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'id': 158, 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 'id': 159, 'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'id': 160, 
'def': 'a roughly cylindrical vessel that is open at the top', 'name': 'bucket'}, {'frequency': 'r', 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'id': 161, 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'synset': 'bull.n.11', 'synonyms': ['horned_cow'], 'id': 162, 'def': 'a cow with horns', 'name': 'bull'}, {'frequency': 'c', 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'id': 163, 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'id': 164, 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'id': 165, 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'id': 166, 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'id': 167, 'def': 'a vest capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'id': 168, 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'f', 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'id': 169, 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'id': 170, 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'id': 171, 'def': 'a float attached by rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 
'r', 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'id': 172, 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'id': 173, 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'id': 174, 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'f', 'synset': 'butter.n.01', 'synonyms': ['butter'], 'id': 175, 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'id': 176, 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'synset': 'button.n.01', 'synonyms': ['button'], 'id': 177, 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'synset': 'cab.n.03', 'synonyms': ['cab_(taxi)', 'taxi', 'taxicab'], 'id': 178, 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'id': 179, 'def': 'a small tent used as a dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'c', 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'id': 180, 'def': 'a car on a freight train for use of the train crew; usually the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'id': 181, 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'synset': 
'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'id': 182, 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'synset': 'cake.n.03', 'synonyms': ['cake'], 'id': 183, 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'id': 184, 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'id': 185, 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'synset': 'calf.n.01', 'synonyms': ['calf'], 'id': 186, 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'id': 187, 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'synset': 'camel.n.01', 'synonyms': ['camel'], 'id': 188, 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'synset': 'camera.n.01', 'synonyms': ['camera'], 'id': 189, 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'id': 190, 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, {'frequency': 'c', 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'id': 191, 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 'f', 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'id': 192, 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'id': 193, 'def': 
'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'f', 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'id': 194, 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'id': 195, 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'id': 196, 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'id': 197, 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'id': 198, 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'id': 199, 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'c', 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'id': 200, 'def': 'small and light boat; pointed at both ends; propelled with a paddle', 'name': 'canoe'}, {'frequency': 'c', 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'id': 201, 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'id': 202, 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'f', 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'id': 203, 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'id': 204, 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'c', 'synset': 'cape.n.02', 'synonyms': ['cape'], 'id': 205, 'def': 'a 
sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'id': 206, 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'id': 207, 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'id': 208, 'def': 'a wheeled vehicle adapted to the rails of railroad (mark each individual railcar separately)', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'id': 209, 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'id': 210, 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'id': 211, 'def': 'a card certifying the identity of the bearer', 'name': 'identity_card'}, {'frequency': 'c', 'synset': 'card.n.03', 'synonyms': ['card'], 'id': 212, 'def': 'a rectangular piece of paper used to send messages (e.g. 
greetings or pictures)', 'name': 'card'}, {'frequency': 'c', 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'id': 213, 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'id': 214, 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'id': 215, 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'id': 216, 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'id': 217, 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'f', 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'id': 218, 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'synset': 'cart.n.01', 'synonyms': ['cart'], 'id': 219, 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'synset': 'carton.n.02', 'synonyms': ['carton'], 'id': 220, 'def': 'a container made of cardboard for holding food or drink', 'name': 'carton'}, {'frequency': 'c', 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'id': 221, 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'id': 222, 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 'id': 223, 'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'synset': 
'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'id': 224, 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'synset': 'cat.n.01', 'synonyms': ['cat'], 'id': 225, 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'f', 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'id': 226, 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'c', 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'id': 227, 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'id': 228, 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'f', 'synset': 'celery.n.01', 'synonyms': ['celery'], 'id': 229, 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'id': 230, 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'id': 231, 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'synset': 'chair.n.01', 'synonyms': ['chair'], 'id': 232, 'def': 'a seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'synset': 'chaise_longue.n.01', 'synonyms': ['chaise_longue', 'chaise', 'daybed'], 'id': 233, 'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'synset': 'chalice.n.01', 'synonyms': ['chalice'], 'id': 234, 'def': 'a 
bowl-shaped drinking vessel; especially the Eucharistic cup', 'name': 'chalice'}, {'frequency': 'f', 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'id': 235, 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'synset': 'chap.n.04', 'synonyms': ['chap'], 'id': 236, 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'id': 237, 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'id': 238, 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'id': 239, 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'id': 240, 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'c', 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'id': 241, 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'synset': 'chickpea.n.01', 'synonyms': ['chickpea', 'garbanzo'], 'id': 242, 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'c', 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'id': 243, 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'id': 244, 'def': 'an instrument consisting of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'id': 245, 
'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'id': 246, 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'id': 247, 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'id': 248, 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'id': 249, 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'id': 250, 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'synset': 'chocolate_mousse.n.01', 'synonyms': ['chocolate_mousse'], 'id': 251, 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'synset': 'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'id': 252, 'def': 'shirt collar, animal collar, or tight-fitting necklace', 'name': 'choker'}, {'frequency': 'f', 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'id': 253, 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'f', 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'id': 254, 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'id': 255, 'def': 'an ornamented evergreen used as a Christmas decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'synset': 'chute.n.02', 'synonyms': ['slide'], 'id': 256, 'def': 'sloping channel through which 
things can descend', 'name': 'slide'}, {'frequency': 'r', 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'id': 257, 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'id': 258, 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'f', 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'id': 259, 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'id': 260, 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'id': 261, 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'id': 262, 'def': 'a single-reed instrument with a straight tube', 'name': 'clarinet'}, {'frequency': 'c', 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'id': 263, 'def': 'a fastener (as a buckle or hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'id': 264, 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'synset': 'cleat.n.02', 'synonyms': ['cleat_(for_securing_rope)'], 'id': 265, 'def': 'a fastener (usually with two projecting horns) around which a rope can be secured', 'name': 'cleat_(for_securing_rope)'}, {'frequency': 'r', 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'id': 266, 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'synset': 'clip.n.03', 'synonyms': ['clip'], 'id': 267, 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'synset': 
'clipboard.n.01', 'synonyms': ['clipboard'], 'id': 268, 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'r', 'synset': 'clipper.n.03', 'synonyms': ['clippers_(for_plants)'], 'id': 269, 'def': 'shears for cutting grass or shrubbery (often used in the plural)', 'name': 'clippers_(for_plants)'}, {'frequency': 'r', 'synset': 'cloak.n.02', 'synonyms': ['cloak'], 'id': 270, 'def': 'a loose outer garment', 'name': 'cloak'}, {'frequency': 'f', 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'id': 271, 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'id': 272, 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'id': 273, 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'id': 274, 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'id': 275, 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'id': 276, 'def': 'a covering (plate or mat) that protects the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'synset': 'coat.n.01', 'synonyms': ['coat'], 'id': 277, 'def': 'an outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'synset': 'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'id': 278, 'def': "a hanger that is shaped like a person's shoulders", 'name': 
'coat_hanger'}, {'frequency': 'c', 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'id': 279, 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'id': 280, 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'r', 'synset': 'cockroach.n.01', 'synonyms': ['cockroach'], 'id': 281, 'def': 'any of numerous chiefly nocturnal insects; some are domestic pests', 'name': 'cockroach'}, {'frequency': 'r', 'synset': 'cocoa.n.01', 'synonyms': ['cocoa_(beverage)', 'hot_chocolate_(beverage)', 'drinking_chocolate'], 'id': 282, 'def': 'a beverage made from cocoa powder and milk and sugar; usually drunk hot', 'name': 'cocoa_(beverage)'}, {'frequency': 'c', 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'id': 283, 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'f', 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'id': 284, 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'id': 285, 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'id': 286, 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'synset': 'coil.n.05', 'synonyms': ['coil'], 'id': 287, 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'synset': 'coin.n.01', 'synonyms': ['coin'], 'id': 288, 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 'c', 'synset': 'colander.n.01', 'synonyms': ['colander', 'cullender'], 'id': 289, 'def': 'bowl-shaped strainer; used to wash or drain foods', 'name': 'colander'}, {'frequency': 'c', 
'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'id': 290, 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'id': 291, 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'id': 292, 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'id': 293, 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'id': 294, 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'r', 'synset': 'compass.n.01', 'synonyms': ['compass'], 'id': 295, 'def': 'navigational instrument for finding directions', 'name': 'compass'}, {'frequency': 'f', 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'id': 296, 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'f', 'synset': 'condiment.n.01', 'synonyms': ['condiment'], 'id': 297, 'def': 'a preparation (a sauce or relish or spice) to enhance flavor or enjoyment', 'name': 'condiment'}, {'frequency': 'f', 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'id': 298, 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'id': 299, 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'synset': 'convertible.n.01', 'synonyms': ['convertible_(automobile)'], 'id': 300, 'def': 'a car that has top that can be folded or removed', 'name': 'convertible_(automobile)'}, {'frequency': 'r', 'synset': 
'convertible.n.03', 'synonyms': ['sofa_bed'], 'id': 301, 'def': 'a sofa that can be converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'r', 'synset': 'cooker.n.01', 'synonyms': ['cooker'], 'id': 302, 'def': 'a utensil for cooking', 'name': 'cooker'}, {'frequency': 'f', 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'id': 303, 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'id': 304, 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'id': 305, 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'f', 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'id': 306, 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'id': 307, 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'c', 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'id': 308, 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'f', 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'id': 309, 'def': 'ears or kernels of corn that can be prepared and served for human food (only mark individual ears or kernels)', 'name': 'edible_corn'}, {'frequency': 'r', 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'id': 310, 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'synset': 'cornet.n.01', 'synonyms': ['cornet', 'horn', 'trumpet'], 'id': 311, 'def': 'a brass musical instrument with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, 
{'frequency': 'c', 'synset': 'cornice.n.01', 'synonyms': ['cornice', 'valance', 'valance_board', 'pelmet'], 'id': 312, 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'id': 313, 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'c', 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'id': 314, 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'c', 'synset': 'costume.n.04', 'synonyms': ['costume'], 'id': 315, 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'id': 316, 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'id': 317, 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'c', 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'id': 318, 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'id': 319, 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'c', 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'id': 320, 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'r', 'synset': 'crab.n.05', 'synonyms': ['crabmeat'], 'id': 321, 'def': 'the edible flesh of any of various crabs', 'name': 'crabmeat'}, {'frequency': 'c', 'synset': 'cracker.n.01', 'synonyms': ['cracker'], 'id': 322, 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'synset': 'crape.n.01', 
'synonyms': ['crape', 'crepe', 'French_pancake'], 'id': 323, 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'synset': 'crate.n.01', 'synonyms': ['crate'], 'id': 324, 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'c', 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'id': 325, 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'id': 326, 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'c', 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'id': 327, 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'id': 328, 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'id': 329, 'def': 'an earthen jar (made of baked clay) or a modern electric crockpot', 'name': 'crock_pot'}, {'frequency': 'f', 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'id': 330, 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'id': 331, 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'c', 'synset': 'crow.n.01', 'synonyms': ['crow'], 'id': 332, 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'r', 'synset': 'crowbar.n.01', 'synonyms': ['crowbar', 'wrecking_bar', 'pry_bar'], 'id': 333, 'def': 'a heavy iron lever with one end forged into a wedge', 'name': 'crowbar'}, {'frequency': 'c', 'synset': 'crown.n.04', 'synonyms': ['crown'], 'id': 334, 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 
'synset': 'crucifix.n.01', 'synonyms': ['crucifix'], 'id': 335, 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'id': 336, 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'id': 337, 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'f', 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'id': 338, 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'c', 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'id': 339, 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'id': 340, 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'c', 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'id': 341, 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'id': 342, 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'id': 343, 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'synset': 'cup.n.01', 'synonyms': ['cup'], 'id': 344, 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'synset': 'cup.n.08', 'synonyms': ['trophy_cup'], 'id': 345, 'def': 'a metal award or cup-shaped vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, 
{'frequency': 'f', 'synset': 'cupboard.n.01', 'synonyms': ['cupboard', 'closet'], 'id': 346, 'def': 'a small room (or recess) or cabinet used for storage space', 'name': 'cupboard'}, {'frequency': 'f', 'synset': 'cupcake.n.01', 'synonyms': ['cupcake'], 'id': 347, 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'id': 348, 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'id': 349, 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'id': 350, 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'id': 351, 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'id': 352, 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'id': 353, 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'id': 354, 'def': 'a short knife with a pointed blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'synset': 'dalmatian.n.02', 'synonyms': ['dalmatian'], 'id': 355, 'def': 'a large breed having a smooth white coat with black or brown spots', 'name': 'dalmatian'}, {'frequency': 'c', 'synset': 'dartboard.n.01', 'synonyms': ['dartboard'], 'id': 356, 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'synset': 'date.n.08', 
'synonyms': ['date_(fruit)'], 'id': 357, 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'id': 358, 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'id': 359, 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'id': 360, 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'synset': 'desk.n.01', 'synonyms': ['desk'], 'id': 361, 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'id': 362, 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'id': 363, 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'id': 364, 'def': 'yearly planner book', 'name': 'diary'}, {'frequency': 'r', 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'id': 365, 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'synset': 'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'id': 366, 'def': 'a small boat of shallow draft with seats and oars with which it is propelled', 'name': 'dinghy'}, {'frequency': 'f', 'synset': 'dining_table.n.01', 'synonyms': ['dining_table'], 'id': 367, 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'synset': 
'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'id': 368, 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'f', 'synset': 'dish.n.01', 'synonyms': ['dish'], 'id': 369, 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'id': 370, 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'id': 371, 'def': 'a cloth for washing dishes or cleaning in general', 'name': 'dishrag'}, {'frequency': 'f', 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'id': 372, 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'id': 373, 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid', 'dishsoap'], 'id': 374, 'def': 'dishsoap or dish detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'f', 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'id': 375, 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'r', 'synset': 'diving_board.n.01', 'synonyms': ['diving_board'], 'id': 376, 'def': 'a springboard from which swimmers can dive', 'name': 'diving_board'}, {'frequency': 'f', 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'id': 377, 'def': 'a disposable cup made of paper; for holding drinks', 'name': 'Dixie_cup'}, {'frequency': 'f', 'synset': 'dog.n.01', 'synonyms': ['dog'], 'id': 378, 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'id': 
379, 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'f', 'synset': 'doll.n.01', 'synonyms': ['doll'], 'id': 380, 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'id': 381, 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'synset': 'dollhouse.n.01', 'synonyms': ['dollhouse', "doll's_house"], 'id': 382, 'def': "a house so small that it is likened to a child's plaything", 'name': 'dollhouse'}, {'frequency': 'c', 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'id': 383, 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'id': 384, 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'f', 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'id': 385, 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'id': 386, 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'synset': 'doughnut.n.02', 'synonyms': ['doughnut', 'donut'], 'id': 387, 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'synset': 'dove.n.01', 'synonyms': ['dove'], 'id': 388, 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'id': 389, 'def': 'slender-bodied non-stinging insect having iridescent wings that are outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'id': 390, 'def': 'a boxlike container in a piece of 
furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'id': 391, 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'id': 392, 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'id': 393, 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'f', 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'id': 394, 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'f', 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'id': 395, 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'synset': 'drill.n.01', 'synonyms': ['drill'], 'id': 396, 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'synset': 'drone.n.04', 'synonyms': ['drone'], 'id': 397, 'def': 'an aircraft without a pilot that is operated by remote control', 'name': 'drone'}, {'frequency': 'r', 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 'id': 398, 'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'id': 399, 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'synset': 'drumstick.n.02', 'synonyms': ['drumstick'], 'id': 400, 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'synset': 'duck.n.01', 'synonyms': 
['duck'], 'id': 401, 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'c', 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'id': 402, 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'id': 403, 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'id': 404, 'def': 'a large cylindrical bag of heavy cloth (does not include suitcases)', 'name': 'duffel_bag'}, {'frequency': 'r', 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'id': 405, 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'id': 406, 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'id': 407, 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'c', 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'id': 408, 'def': 'large birds of prey noted for their broad wings and strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'id': 409, 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'id': 410, 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'synset': 'earring.n.01', 'synonyms': ['earring'], 'id': 411, 'def': 'jewelry to ornament the ear', 'name': 'earring'}, {'frequency': 'c', 'synset': 'easel.n.01', 'synonyms': ['easel'], 'id': 412, 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, 
{'frequency': 'r', 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'id': 413, 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'synset': 'eel.n.01', 'synonyms': ['eel'], 'id': 414, 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'id': 415, 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'id': 416, 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'id': 417, 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'id': 418, 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'id': 419, 'def': 'egg-shaped vegetable having a shiny skin typically dark purple', 'name': 'eggplant'}, {'frequency': 'r', 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'id': 420, 'def': 'a chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'id': 421, 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'id': 422, 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'c', 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'id': 423, 'def': 'large northern deer with enormous flattened antlers in the male', 'name': 'elk'}, {'frequency': 'c', 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'id': 424, 'def': 'a flat (usually rectangular) container for a letter, thin 
package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'id': 425, 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'id': 426, 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'id': 427, 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'id': 428, 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'synset': 'fan.n.01', 'synonyms': ['fan'], 'id': 429, 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'id': 430, 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'id': 431, 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'id': 432, 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'id': 433, 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'c', 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'id': 434, 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular schedule', 'name': 'ferry'}, {'frequency': 'r', 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'id': 435, 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten 
fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'id': 436, 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'id': 437, 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'id': 438, 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'id': 439, 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'id': 440, 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'f', 'synset': 'fire_engine.n.01', 'synonyms': ['fire_engine', 'fire_truck'], 'id': 441, 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, {'frequency': 'f', 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'id': 442, 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'id': 443, 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'id': 444, 'def': 'an open recess in a wall at the base of a chimney where a fire can be built', 'name': 'fireplace'}, {'frequency': 'f', 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'id': 445, 'def': 'an upright hydrant for 
drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'r', 'synset': 'first-aid_kit.n.01', 'synonyms': ['first-aid_kit'], 'id': 446, 'def': 'kit consisting of a set of bandages and medicines for giving first aid', 'name': 'first-aid_kit'}, {'frequency': 'f', 'synset': 'fish.n.01', 'synonyms': ['fish'], 'id': 447, 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'c', 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'id': 448, 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'id': 449, 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'c', 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'id': 450, 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'synset': 'flag.n.01', 'synonyms': ['flag'], 'id': 451, 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 'id': 452, 'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'id': 453, 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'id': 454, 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'c', 'synset': 'flap.n.01', 'synonyms': ['flap'], 'id': 455, 'def': 'any broad thin covering attached at one edge, such as a mud flap next to a wheel or a flap on an airplane wing', 'name': 'flap'}, {'frequency': 'r', 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'id': 456, 'def': 
'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'id': 457, 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'id': 458, 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'id': 459, 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'id': 460, 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'id': 461, 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'id': 462, 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'c', 'synset': 'foal.n.01', 'synonyms': ['foal'], 'id': 463, 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'synset': 'folding_chair.n.01', 'synonyms': ['folding_chair'], 'id': 464, 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'synset': 'food_processor.n.01', 'synonyms': ['food_processor'], 'id': 465, 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'id': 466, 'def': 'the inflated oblong ball used in playing American football', 'name': 'football_(American)'}, {'frequency': 'r', 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'id': 467, 'def': 'a padded helmet with a face mask to 
protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'id': 468, 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'synset': 'fork.n.01', 'synonyms': ['fork'], 'id': 469, 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'c', 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'id': 470, 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'c', 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'id': 471, 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'c', 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'id': 472, 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'id': 473, 'def': 'anything that freshens air by removing or covering odor', 'name': 'freshener'}, {'frequency': 'f', 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'id': 474, 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'synset': 'frog.n.01', 'synonyms': ['frog', 'toad', 'toad_frog'], 'id': 475, 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, {'frequency': 'c', 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'id': 476, 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'f', 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'id': 477, 'def': 'a pan used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'id': 478, 'def': 'soft creamy candy', 'name': 'fudge'}, 
{'frequency': 'r', 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'id': 479, 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'r', 'synset': 'futon.n.01', 'synonyms': ['futon'], 'id': 480, 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'id': 481, 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'id': 482, 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'id': 483, 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'id': 484, 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'id': 485, 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'synset': 'gargoyle.n.02', 'synonyms': ['gargoyle'], 'id': 486, 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 'gargoyle'}, {'frequency': 'c', 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'id': 487, 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'id': 488, 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'c', 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'id': 489, 'def': 'small swift graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'id': 490, 
'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'id': 491, 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'r', 'synset': 'generator.n.02', 'synonyms': ['generator'], 'id': 492, 'def': 'engine that converts mechanical energy into electrical energy by electromagnetic induction', 'name': 'generator'}, {'frequency': 'c', 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'id': 493, 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'id': 494, 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'id': 495, 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'id': 496, 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'id': 497, 'def': 'a band of material around the waist that strengthens a skirt or trousers', 'name': 'cincture'}, {'frequency': 'f', 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 'drinking_glass'], 'id': 498, 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'synset': 'globe.n.03', 'synonyms': ['globe'], 'id': 499, 'def': 'a sphere on which a map (especially of the earth) is represented', 'name': 'globe'}, {'frequency': 'f', 'synset': 'glove.n.02', 'synonyms': ['glove'], 'id': 500, 'def': 'handwear covering the hand', 'name': 'glove'}, 
{'frequency': 'c', 'synset': 'goat.n.01', 'synonyms': ['goat'], 'id': 501, 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'id': 502, 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'id': 503, 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'c', 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'id': 504, 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'id': 505, 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'id': 506, 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'synset': 'goose.n.01', 'synonyms': ['goose'], 'id': 507, 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'id': 508, 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'id': 509, 'def': 'any of numerous inedible fruits with hard rinds', 'name': 'gourd'}, {'frequency': 'f', 'synset': 'grape.n.01', 'synonyms': ['grape'], 'id': 510, 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'c', 'synset': 'grater.n.01', 'synonyms': ['grater'], 'id': 511, 'def': 'utensil with sharp perforations for shredding foods (as vegetables or cheese)', 'name': 'grater'}, {'frequency': 'c', 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'id': 512, 'def': 'a stone that is used to 
mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'id': 513, 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'f', 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'id': 514, 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'f', 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'id': 515, 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'id': 516, 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'f', 'synset': 'grill.n.02', 'synonyms': ['grill', 'grille', 'grillwork', 'radiator_grille'], 'id': 517, 'def': 'a framework of metal bars used as a partition or a grate', 'name': 'grill'}, {'frequency': 'r', 'synset': 'grits.n.01', 'synonyms': ['grits', 'hominy_grits'], 'id': 518, 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'id': 519, 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'id': 520, 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'f', 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'id': 521, 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'id': 522, 'def': 'mostly white aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'synset': 'gun.n.01', 'synonyms': ['gun'], 'id': 523, 'def': 'a weapon that 
discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'f', 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'id': 524, 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'id': 525, 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'id': 526, 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, {'frequency': 'r', 'synset': 'halter.n.03', 'synonyms': ['halter_top'], 'id': 527, 'def': "a woman's top that fastens behind the back and neck leaving the back and arms uncovered", 'name': 'halter_top'}, {'frequency': 'f', 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'id': 528, 'def': 'meat cut from the thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'id': 529, 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'id': 530, 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'c', 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'id': 531, 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'id': 532, 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'c', 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'id': 533, 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 'hamster'}, {'frequency': 'f', 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'id': 534, 'def': 'a hand-held electric blower that can blow warm air onto the 
hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'id': 535, 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'id': 536, 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'id': 537, 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'id': 538, 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'id': 539, 'def': 'a square piece of cloth used for wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'id': 540, 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'id': 541, 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'id': 542, 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'id': 543, 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'synset': 'hat.n.01', 'synonyms': ['hat'], 'id': 544, 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 
'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'id': 545, 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'c', 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'id': 546, 'def': 'a garment that covers the head OR face', 'name': 'veil'}, {'frequency': 'f', 'synset': 'headband.n.01', 'synonyms': ['headband'], 'id': 547, 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'id': 548, 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'id': 549, 'def': 'a powerful light with reflector; attached to the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'id': 550, 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'synset': 'headset.n.01', 'synonyms': ['headset'], 'id': 551, 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'id': 552, 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'c', 'synset': 'heart.n.02', 'synonyms': ['heart'], 'id': 553, 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'id': 554, 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'id': 555, 'def': 'an aircraft without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'synset': 'helmet.n.02', 'synonyms': 
['helmet'], 'id': 556, 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'synset': 'heron.n.02', 'synonyms': ['heron'], 'id': 557, 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'id': 558, 'def': 'a chair for feeding a very young child', 'name': 'highchair'}, {'frequency': 'f', 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'id': 559, 'def': 'a joint that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'id': 560, 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'id': 561, 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'id': 562, 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'id': 563, 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'synset': 'honey.n.01', 'synonyms': ['honey'], 'id': 564, 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'id': 565, 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'synset': 'hook.n.05', 'synonyms': ['hook'], 'id': 566, 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'r', 'synset': 'hookah.n.01', 
'synonyms': ['hookah', 'narghile', 'nargileh', 'sheesha', 'shisha', 'water_pipe'], 'id': 567, 'def': 'a tobacco pipe with a long flexible tube connected to a container where the smoke is cooled by passing through water', 'name': 'hookah'}, {'frequency': 'r', 'synset': 'hornet.n.01', 'synonyms': ['hornet'], 'id': 568, 'def': 'large stinging wasp', 'name': 'hornet'}, {'frequency': 'f', 'synset': 'horse.n.01', 'synonyms': ['horse'], 'id': 569, 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'id': 570, 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'id': 571, 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 'hot-air_balloon'}, {'frequency': 'r', 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'id': 572, 'def': 'a portable electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'id': 573, 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'id': 574, 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'id': 575, 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'c', 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'id': 576, 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 'humous'], 'id': 577, 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'f', 'synset': 'ice_bear.n.01', 'synonyms': 
['polar_bear'], 'id': 578, 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'id': 579, 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'id': 580, 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'id': 581, 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'id': 582, 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'id': 583, 'def': 'skate consisting of a boot with a steel blade fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'c', 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'id': 584, 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'id': 585, 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'f', 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'id': 586, 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'id': 587, 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'c', 'synset': 'ironing_board.n.01', 'synonyms': ['ironing_board'], 'id': 588, 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 
'ironing_board'}, {'frequency': 'f', 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'id': 589, 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'c', 'synset': 'jam.n.01', 'synonyms': ['jam'], 'id': 590, 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'synset': 'jar.n.01', 'synonyms': ['jar'], 'id': 591, 'def': 'a vessel (usually cylindrical) with a wide mouth and without handles', 'name': 'jar'}, {'frequency': 'f', 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'id': 592, 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'synset': 'jeep.n.01', 'synonyms': ['jeep', 'landrover'], 'id': 593, 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'id': 594, 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'id': 595, 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'id': 596, 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'r', 'synset': 'jewel.n.01', 'synonyms': ['jewel', 'gem', 'precious_stone'], 'id': 597, 'def': 'a precious or semiprecious stone incorporated into a piece of jewelry', 'name': 'jewel'}, {'frequency': 'c', 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'id': 598, 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'id': 599, 'def': 'a control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'c', 'synset': 
'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'id': 600, 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'id': 601, 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'synset': 'keg.n.02', 'synonyms': ['keg'], 'id': 602, 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'id': 603, 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'id': 604, 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'synset': 'key.n.01', 'synonyms': ['key'], 'id': 605, 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, {'frequency': 'r', 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'id': 606, 'def': 'a plastic card used to gain access typically to a door', 'name': 'keycard'}, {'frequency': 'c', 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'id': 607, 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'id': 608, 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'id': 609, 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'r', 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'id': 610, 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'synset': 'kite.n.03', 'synonyms': ['kite'], 'id': 611, 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 
'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'id': 612, 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'id': 613, 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'id': 614, 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'synset': 'knife.n.01', 'synonyms': ['knife'], 'id': 615, 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'synset': 'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'id': 616, 'def': 'needle consisting of a slender rod with pointed ends; usually used in pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'synset': 'knob.n.02', 'synonyms': ['knob'], 'id': 617, 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'id': 618, 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'id': 619, 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'id': 620, 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'id': 621, 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'synset': 'ladle.n.01', 'synonyms': ['ladle'], 'id': 622, 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 
'c', 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'id': 623, 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'f', 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'id': 624, 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'id': 625, 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'id': 626, 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'synset': 'lamppost.n.01', 'synonyms': ['lamppost'], 'id': 627, 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, {'frequency': 'f', 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'id': 628, 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'id': 629, 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'id': 630, 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'id': 631, 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'id': 632, 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'f', 'synset': 'latch.n.02', 'synonyms': ['latch'], 'id': 633, 'def': 'a bar that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'synset': 'lawn_mower.n.01', 'synonyms': 
['lawn_mower'], 'id': 634, 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'synset': 'leather.n.01', 'synonyms': ['leather'], 'id': 635, 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'id': 636, 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'id': 637, 'def': "a child's plastic construction set for making models from blocks", 'name': 'Lego'}, {'frequency': 'r', 'synset': 'legume.n.02', 'synonyms': ['legume'], 'id': 638, 'def': 'the fruit or seed of bean or pea plants', 'name': 'legume'}, {'frequency': 'f', 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'id': 639, 'def': 'yellow oval fruit with juicy acidic flesh', 'name': 'lemon'}, {'frequency': 'r', 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'id': 640, 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'id': 641, 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'id': 642, 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'id': 643, 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'synset': 'life_jacket.n.01', 'synonyms': ['life_jacket', 'life_vest'], 'id': 644, 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 
'life_jacket'}, {'frequency': 'f', 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'id': 645, 'def': 'lightblub/source of light', 'name': 'lightbulb'}, {'frequency': 'r', 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'id': 646, 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'f', 'synset': 'lime.n.06', 'synonyms': ['lime'], 'id': 647, 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'id': 648, 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'c', 'synset': 'lion.n.01', 'synonyms': ['lion'], 'id': 649, 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'id': 650, 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'r', 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'id': 651, 'def': 'liquor or beer', 'name': 'liquor'}, {'frequency': 'c', 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'id': 652, 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'f', 'synset': 'log.n.01', 'synonyms': ['log'], 'id': 653, 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'id': 654, 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'f', 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'id': 655, 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 'speaker_(stero_equipment)'}, {'frequency': 'c', 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'id': 656, 'def': 'small sofa that seats two people', 'name': 
'loveseat'}, {'frequency': 'r', 'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'id': 657, 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'id': 658, 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'id': 659, 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'c', 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'id': 660, 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'f', 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'id': 661, 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'synset': 'mallard.n.01', 'synonyms': ['mallard'], 'id': 662, 'def': 'wild dabbling duck from which domestic ducks are descended', 'name': 'mallard'}, {'frequency': 'r', 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'id': 663, 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'id': 664, 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'r', 'synset': 'manatee.n.01', 'synonyms': ['manatee'], 'id': 665, 'def': 'sirenian mammal of tropical coastal waters of America', 'name': 'manatee'}, {'frequency': 'c', 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'id': 666, 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'id': 667, 'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'synset': 'manhole.n.01', 'synonyms': 
['manhole'], 'id': 668, 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'f', 'synset': 'map.n.01', 'synonyms': ['map'], 'id': 669, 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'f', 'synset': 'marker.n.03', 'synonyms': ['marker'], 'id': 670, 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'synset': 'martini.n.01', 'synonyms': ['martini'], 'id': 671, 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'id': 672, 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'id': 673, 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'synset': 'masher.n.02', 'synonyms': ['masher'], 'id': 674, 'def': 'a kitchen utensil used for mashing (e.g. 
potatoes)', 'name': 'masher'}, {'frequency': 'f', 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'id': 675, 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'synset': 'mast.n.01', 'synonyms': ['mast'], 'id': 676, 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'id': 677, 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'id': 678, 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'id': 679, 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'id': 680, 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'id': 681, 'def': 'measuring instrument having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'synset': 'meatball.n.01', 'synonyms': ['meatball'], 'id': 682, 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'id': 683, 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'c', 'synset': 'melon.n.01', 'synonyms': ['melon'], 'id': 684, 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'id': 685, 'def': 'device for converting sound waves into electrical 
energy', 'name': 'microphone'}, {'frequency': 'r', 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'id': 686, 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'id': 687, 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'id': 688, 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'f', 'synset': 'milk.n.01', 'synonyms': ['milk'], 'id': 689, 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'r', 'synset': 'milk_can.n.01', 'synonyms': ['milk_can'], 'id': 690, 'def': 'can for transporting milk', 'name': 'milk_can'}, {'frequency': 'r', 'synset': 'milkshake.n.01', 'synonyms': ['milkshake'], 'id': 691, 'def': 'frothy drink of milk and flavoring and sometimes fruit or ice cream', 'name': 'milkshake'}, {'frequency': 'f', 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'id': 692, 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'id': 693, 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'id': 694, 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'id': 695, 'def': 'glove that encases the thumb separately and the other four fingers together', 'name': 'mitten'}, {'frequency': 'c', 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'id': 696, 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'synset': 'money.n.03', 'synonyms': ['money'], 'id': 
697, 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'id': 698, 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'id': 699, 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'synset': 'motor.n.01', 'synonyms': ['motor'], 'id': 700, 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'id': 701, 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'id': 702, 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'f', 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'id': 703, 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'id': 704, 'def': '(baseball) the slight elevation on which the pitcher stands', 'name': 'mound_(baseball)'}, {'frequency': 'f', 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'id': 705, 'def': 'a computer input device that controls an on-screen pointer (does not include trackpads / touchpads)', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'id': 706, 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'id': 707, 'def': 'a sweet quick bread baked in 
a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'synset': 'mug.n.04', 'synonyms': ['mug'], 'id': 708, 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'id': 709, 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'id': 710, 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'c', 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'id': 711, 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'id': 712, 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'f', 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'id': 713, 'def': 'a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 'name': 'napkin'}, {'frequency': 'r', 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'id': 714, 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'id': 715, 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'id': 716, 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'c', 'synset': 'needle.n.03', 'synonyms': ['needle'], 'id': 717, 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'synset': 'nest.n.01', 'synonyms': ['nest'], 'id': 718, 'def': 'a 
structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'f', 'synset': 'newspaper.n.01', 'synonyms': ['newspaper', 'paper_(newspaper)'], 'id': 719, 'def': 'a daily or weekly publication on folded sheets containing news, articles, and advertisements', 'name': 'newspaper'}, {'frequency': 'c', 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'id': 720, 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'id': 721, 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'id': 722, 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'c', 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'id': 723, 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 'id': 724, 'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'id': 725, 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'f', 'synset': 'nut.n.03', 'synonyms': ['nut'], 'id': 726, 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'id': 727, 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'f', 'synset': 'oar.n.01', 'synonyms': ['oar'], 'id': 728, 'def': 'an implement used to propel or steer a boat', 'name': 
'oar'}, {'frequency': 'r', 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'id': 729, 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'id': 730, 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'id': 731, 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'id': 732, 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'id': 733, 'def': 'beaten eggs cooked until just set; may be folded around e.g. ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'synset': 'onion.n.01', 'synonyms': ['onion'], 'id': 734, 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'id': 735, 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'id': 736, 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'c', 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'id': 737, 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'f', 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'id': 738, 'def': 'a thick standalone cushion used as a seat or footrest, often next to a chair', 'name': 'ottoman'}, {'frequency': 'f', 'synset': 'oven.n.01', 'synonyms': ['oven'], 'id': 739, 'def': 'kitchen appliance used for baking or roasting', 'name': 'oven'}, {'frequency': 'c', 'synset': 'overall.n.01', 'synonyms': 
['overalls_(clothing)'], 'id': 740, 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'synset': 'owl.n.01', 'synonyms': ['owl'], 'id': 741, 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'synset': 'packet.n.03', 'synonyms': ['packet'], 'id': 742, 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'id': 743, 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'synset': 'pad.n.04', 'synonyms': ['pad'], 'id': 744, 'def': 'mostly arm/knee pads labeled', 'name': 'pad'}, {'frequency': 'f', 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'id': 745, 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'id': 746, 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'c', 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'id': 747, 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'synset': 'painting.n.01', 'synonyms': ['painting'], 'id': 748, 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'f', 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'id': 749, 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'id': 750, 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 
'cooking_pan'], 'id': 751, 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'id': 752, 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'id': 753, 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'id': 754, 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'synset': 'papaya.n.02', 'synonyms': ['papaya'], 'id': 755, 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'f', 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'id': 756, 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'id': 757, 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'id': 758, 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'id': 759, 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'id': 760, 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'c', 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'id': 761, 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'synset': 'parasail.n.01', 'synonyms': 
['parasail_(sports)'], 'id': 762, 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'c', 'synset': 'parasol.n.01', 'synonyms': ['parasol', 'sunshade'], 'id': 763, 'def': 'a handheld collapsible source of shade', 'name': 'parasol'}, {'frequency': 'r', 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'id': 764, 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'c', 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'id': 765, 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 'f', 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'id': 766, 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'id': 767, 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'id': 768, 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'id': 769, 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'c', 'synset': 'passport.n.02', 'synonyms': ['passport'], 'id': 770, 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'id': 771, 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'id': 772, 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'synset': 
'pea.n.01', 'synonyms': ['pea_(food)'], 'id': 773, 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'synset': 'peach.n.03', 'synonyms': ['peach'], 'id': 774, 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'id': 775, 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'f', 'synset': 'pear.n.01', 'synonyms': ['pear'], 'id': 776, 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'c', 'synset': 'peeler.n.03', 'synonyms': ['peeler_(tool_for_fruit_and_vegetables)'], 'id': 777, 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'synset': 'peg.n.04', 'synonyms': ['wooden_leg', 'pegleg'], 'id': 778, 'def': 'a prosthesis that replaces a missing leg', 'name': 'wooden_leg'}, {'frequency': 'r', 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'id': 779, 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'id': 780, 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'synset': 'pen.n.01', 'synonyms': ['pen'], 'id': 781, 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'f', 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'id': 782, 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'id': 783, 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'id': 784, 'def': 'a rotary implement for 
sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'id': 785, 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'id': 786, 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'id': 787, 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'id': 788, 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, {'frequency': 'f', 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'id': 789, 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'id': 790, 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'id': 791, 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'id': 792, 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'synset': 'person.n.01', 'synonyms': ['person', 'baby', 'child', 'boy', 'girl', 'man', 'woman', 'human'], 'id': 793, 'def': 'a human being', 'name': 'person'}, {'frequency': 'c', 'synset': 'pet.n.01', 'synonyms': ['pet'], 'id': 794, 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'c', 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'id': 795, 'def': 'long bench with 
backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'id': 796, 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'id': 797, 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'f', 'synset': 'piano.n.01', 'synonyms': ['piano'], 'id': 798, 'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'id': 799, 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'id': 800, 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'synset': 'pie.n.01', 'synonyms': ['pie'], 'id': 801, 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'id': 802, 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'id': 803, 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'id': 804, 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'id': 805, 'def': 'a small slender (often pointed) piece of wood or metal 
used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'id': 806, 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'id': 807, 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'id': 808, 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'id': 809, 'def': 'a toy consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'id': 810, 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'id': 811, 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'id': 812, 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'c', 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'id': 813, 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'id': 814, 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'id': 815, 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 
'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'id': 816, 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'id': 817, 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'synset': 'plate.n.04', 'synonyms': ['plate'], 'id': 818, 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'synset': 'platter.n.01', 'synonyms': ['platter'], 'id': 819, 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'id': 820, 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'id': 821, 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'id': 822, 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'synset': 'plume.n.02', 'synonyms': ['plume'], 'id': 823, 'def': 'a feather or cluster of feathers worn as an ornament', 'name': 'plume'}, {'frequency': 'r', 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'id': 824, 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'id': 825, 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'id': 826, 'def': 'fire iron consisting of a metal 
rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'synset': 'pole.n.01', 'synonyms': ['pole', 'post'], 'id': 827, 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'f', 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'id': 828, 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'id': 829, 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'synset': 'pony.n.05', 'synonyms': ['pony'], 'id': 830, 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'id': 831, 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'id': 832, 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'c', 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'id': 833, 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'id': 834, 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'id': 835, 'def': 'a sign posted in a public place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'synset': 'pot.n.01', 'synonyms': ['pot'], 'id': 836, 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 
'f', 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'id': 837, 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'synset': 'potato.n.01', 'synonyms': ['potato'], 'id': 838, 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'id': 839, 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'id': 840, 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'id': 841, 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'c', 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'id': 842, 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'id': 843, 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'c', 'synset': 'pretzel.n.01', 'synonyms': ['pretzel'], 'id': 844, 'def': 'glazed and salted cracker typically in the shape of a loose knot', 'name': 'pretzel'}, {'frequency': 'f', 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'id': 845, 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'id': 846, 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'synset': 'projector.n.02', 'synonyms': ['projector'], 'id': 847, 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'synset': 'propeller.n.01', 'synonyms': ['propeller', 'propellor'], 'id': 848, 'def': 'a mechanical device that rotates to push against air or 
water', 'name': 'propeller'}, {'frequency': 'r', 'synset': 'prune.n.01', 'synonyms': ['prune'], 'id': 849, 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'id': 850, 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'id': 851, 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'id': 852, 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'id': 853, 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'id': 854, 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'id': 855, 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'id': 856, 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'c', 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'id': 857, 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'id': 858, 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'id': 859, 'def': 'a tart filled with rich unsweetened custard; often contains other ingredients (as cheese or ham or seafood or 
vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'id': 860, 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'id': 861, 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'id': 862, 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'id': 863, 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'synset': 'radar.n.01', 'synonyms': ['radar'], 'id': 864, 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'f', 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'id': 865, 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'id': 866, 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'id': 867, 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'synset': 'raft.n.01', 'synonyms': ['raft'], 'id': 868, 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'id': 869, 'def': 'a cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, 
{'frequency': 'c', 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'id': 870, 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'id': 871, 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'id': 872, 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'synset': 'rat.n.01', 'synonyms': ['rat'], 'id': 873, 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'id': 874, 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'id': 875, 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'id': 876, 'def': 'vehicle mirror (side or rearview)', 'name': 'rearview_mirror'}, {'frequency': 'c', 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'id': 877, 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'id': 878, 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'c', 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'id': 879, 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified acoustically or electronically', 'name': 'record_player'}, {'frequency': 'f', 'synset': 
'reflector.n.01', 'synonyms': ['reflector'], 'id': 880, 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'id': 881, 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'id': 882, 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'id': 883, 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'c', 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'id': 884, 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'synset': 'ring.n.08', 'synonyms': ['ring'], 'id': 885, 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'id': 886, 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'id': 887, 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'synset': 'robe.n.01', 'synonyms': ['robe'], 'id': 888, 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'id': 889, 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'synset': 'rodent.n.01', 'synonyms': ['rodent'], 'id': 890, 'def': 'relatively small placental mammals having a single pair of constantly growing incisor teeth specialized for gnawing', 'name': 'rodent'}, {'frequency': 'r', 'synset': 'roller_skate.n.01', 'synonyms': 
['roller_skate'], 'id': 891, 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'id': 892, 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'id': 893, 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'id': 894, 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'id': 895, 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'id': 896, 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'id': 897, 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'id': 898, 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'id': 899, 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'id': 900, 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'id': 901, 'def': 'a large bag (or pair of bags) hung over a 
saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'id': 902, 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'f', 'synset': 'sail.n.01', 'synonyms': ['sail'], 'id': 903, 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'f', 'synset': 'salad.n.01', 'synonyms': ['salad'], 'id': 904, 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'id': 905, 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'c', 'synset': 'salami.n.01', 'synonyms': ['salami'], 'id': 906, 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'c', 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'id': 907, 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'id': 908, 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'c', 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'id': 909, 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'id': 910, 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'id': 911, 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'synset': 'sandwich.n.01', 
'synonyms': ['sandwich'], 'id': 912, 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'id': 913, 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'id': 914, 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'id': 915, 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'id': 916, 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'id': 917, 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'id': 918, 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'id': 919, 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'id': 920, 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'id': 921, 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'id': 922, 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'synset': 'scissors.n.01', 'synonyms': ['scissors'], 'id': 923, 
'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'f', 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'id': 924, 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'r', 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'id': 925, 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'c', 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'id': 926, 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'f', 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'id': 927, 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'id': 928, 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'c', 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'id': 929, 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'c', 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'id': 930, 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'id': 931, 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'synset': 'seashell.n.01', 'synonyms': ['seashell'], 'id': 932, 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'c', 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'id': 933, 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'c', 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'id': 934, 
'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'id': 935, 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'c', 'synset': 'shark.n.01', 'synonyms': ['shark'], 'id': 936, 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'id': 937, 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'id': 938, 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'id': 939, 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'id': 940, 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'id': 941, 'def': 'cloak consisting of an oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'synset': 'shears.n.01', 'synonyms': ['shears'], 'id': 942, 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'id': 943, 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'id': 944, 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 
'sherbet'], 'id': 945, 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'c', 'synset': 'shield.n.02', 'synonyms': ['shield'], 'id': 946, 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'id': 947, 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'id': 948, 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'f', 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'id': 949, 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'id': 950, 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'id': 951, 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'id': 952, 'def': 'a small glass adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'f', 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'id': 953, 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'id': 954, 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'id': 955, 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'r', 'synset': 'shower_cap.n.01', 'synonyms': ['shower_cap'], 
'id': 956, 'def': 'a tight cap worn to keep hair dry while showering', 'name': 'shower_cap'}, {'frequency': 'f', 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'id': 957, 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'id': 958, 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'f', 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'id': 959, 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'synset': 'silo.n.01', 'synonyms': ['silo'], 'id': 960, 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'synset': 'sink.n.01', 'synonyms': ['sink'], 'id': 961, 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'id': 962, 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'id': 963, 'def': 'a long pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'synset': 'ski.n.01', 'synonyms': ['ski'], 'id': 964, 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'id': 965, 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'id': 966, 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'id': 967, 'def': 'a pole with metal points used as an aid in skiing', 'name': 
'ski_pole'}, {'frequency': 'f', 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'id': 968, 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'r', 'synset': 'skullcap.n.01', 'synonyms': ['skullcap'], 'id': 969, 'def': 'rounded brimless cap fitting the crown of the head', 'name': 'skullcap'}, {'frequency': 'c', 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'id': 970, 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'id': 971, 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'id': 972, 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'id': 973, 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'id': 974, 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'synset': 'snake.n.01', 'synonyms': ['snake', 'serpent'], 'id': 975, 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'id': 976, 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'id': 977, 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'synset': 'snowmobile.n.01', 
'synonyms': ['snowmobile'], 'id': 978, 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'synset': 'soap.n.01', 'synonyms': ['soap'], 'id': 979, 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'id': 980, 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'synset': 'sock.n.01', 'synonyms': ['sock'], 'id': 981, 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'f', 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'id': 982, 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'synset': 'softball.n.01', 'synonyms': ['softball'], 'id': 983, 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'synset': 'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'id': 984, 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'id': 985, 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'f', 'synset': 'soup.n.01', 'synonyms': ['soup'], 'id': 986, 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'id': 987, 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'id': 988, 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'synset': 'sour_cream.n.01', 'synonyms': 
['sour_cream', 'soured_cream'], 'id': 989, 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'id': 990, 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'id': 991, 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'id': 992, 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'id': 993, 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'id': 994, 'def': 'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'id': 995, 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'id': 996, 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'c', 'synset': 'spider.n.01', 'synonyms': ['spider'], 'id': 997, 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'r', 'synset': 'spiny_lobster.n.02', 'synonyms': ['crawfish', 'crayfish'], 'id': 998, 'def': 'large edible marine crustacean having a spiny carapace but lacking the large pincers of true lobsters', 'name': 
'crawfish'}, {'frequency': 'c', 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'id': 999, 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'id': 1000, 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'id': 1001, 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'id': 1002, 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'synset': 'squid.n.01', 'synonyms': ['squid_(food)', 'calamari', 'calamary'], 'id': 1003, 'def': '(Italian cuisine) squid prepared as food', 'name': 'squid_(food)'}, {'frequency': 'c', 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'id': 1004, 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'r', 'synset': 'stagecoach.n.01', 'synonyms': ['stagecoach'], 'id': 1005, 'def': 'a large coach-and-four formerly used to carry passengers and mail on regular routes between towns', 'name': 'stagecoach'}, {'frequency': 'c', 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'id': 1006, 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'c', 'synset': 'starfish.n.01', 'synonyms': ['starfish', 'sea_star'], 'id': 1007, 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'id': 1008, 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'synset': 
'steak.n.01', 'synonyms': ['steak_(food)'], 'id': 1009, 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'id': 1010, 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'f', 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'id': 1011, 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'id': 1012, 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'id': 1013, 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'id': 1014, 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'synset': 'stew.n.02', 'synonyms': ['stew'], 'id': 1015, 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'id': 1016, 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'id': 1017, 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'f', 'synset': 'stool.n.01', 'synonyms': ['stool'], 'id': 1018, 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'id': 1019, 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'id': 1020, 'def': 'a red light on the rear of a motor vehicle that signals when the 
brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'synset': 'stove.n.01', 'synonyms': ['stove', 'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'id': 1021, 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'id': 1022, 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'synset': 'strap.n.01', 'synonyms': ['strap'], 'id': 1023, 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'id': 1024, 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'id': 1025, 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'id': 1026, 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'id': 1027, 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'id': 1028, 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'id': 1029, 'def': 'a pointed tool for writing or drawing or engraving, including pens', 'name': 'stylus'}, {'frequency': 'r', 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'id': 1030, 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'id': 1031, 'def': 'a 
dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'synset': 'sugarcane.n.01', 'synonyms': ['sugarcane_(plant)'], 'id': 1032, 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'f', 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'id': 1033, 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'id': 1034, 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'id': 1035, 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'id': 1036, 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'f', 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'id': 1037, 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'id': 1038, 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'synset': 'swab.n.02', 'synonyms': ['mop'], 'id': 1039, 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'id': 1040, 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'synset': 'sweatband.n.02', 'synonyms': ['sweatband'], 'id': 1041, 'def': 'a band of material tied around the forehead or wrist to 
absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'id': 1042, 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'id': 1043, 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'id': 1044, 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'id': 1045, 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'synset': 'sword.n.01', 'synonyms': ['sword'], 'id': 1046, 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'id': 1047, 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'id': 1048, 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'id': 1049, 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'synset': 'table.n.02', 'synonyms': ['table'], 'id': 1050, 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'id': 1051, 'def': 'a lamp that sits on a table', 'name': 'table_lamp'}, {'frequency': 'f', 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'id': 1052, 
'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'id': 1053, 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'synset': 'taco.n.02', 'synonyms': ['taco'], 'id': 1054, 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'synset': 'tag.n.02', 'synonyms': ['tag'], 'id': 1055, 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'id': 1056, 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'id': 1057, 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'id': 1058, 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'f', 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'id': 1059, 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'id': 1060, 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'f', 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'id': 1061, 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 'c', 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'id': 1062, 'def': 'measuring 
instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'id': 1063, 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'id': 1064, 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'id': 1065, 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'id': 1066, 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'c', 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'id': 1067, 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'id': 1068, 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'id': 1069, 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'f', 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'id': 1070, 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'id': 1071, 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'id': 1072, 'def': 'electronic device for communicating by voice over long distances (includes wired and wireless/cell phones)', 'name': 'telephone'}, {'frequency': 'c', 'synset': 'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 
'telephone_box', 'telephone_kiosk'], 'id': 1073, 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'id': 1074, 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'id': 1075, 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'id': 1076, 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'id': 1077, 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'id': 1078, 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'id': 1079, 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'id': 1080, 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'id': 1081, 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'id': 1082, 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'f', 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'id': 1083, 'def': 'a regulator for automatically regulating temperature by starting or stopping the 
supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'id': 1084, 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'id': 1085, 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'id': 1086, 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'id': 1087, 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'id': 1088, 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'id': 1089, 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'id': 1090, 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'id': 1091, 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'c', 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'id': 1092, 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'id': 1093, 'def': 'a soft thin (usually translucent) paper', 'name': 'tissue_paper'}, {'frequency': 'c', 'synset': 'toast.n.01', 'synonyms': 
['toast_(food)'], 'id': 1094, 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'id': 1095, 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'f', 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'id': 1096, 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'id': 1097, 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'id': 1098, 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'id': 1099, 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'f', 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'id': 1100, 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'id': 1101, 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'id': 1102, 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'id': 1103, 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'f', 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'id': 1104, 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 'name': 'toothpick'}, {'frequency': 'f', 'synset': 'top.n.09', 
'synonyms': ['cover'], 'id': 1105, 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'id': 1106, 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'id': 1107, 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'synset': 'towel.n.01', 'synonyms': ['towel'], 'id': 1108, 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'id': 1109, 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'synset': 'toy.n.03', 'synonyms': ['toy'], 'id': 1110, 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'id': 1111, 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'id': 1112, 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'c', 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'id': 1113, 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'f', 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'id': 1114, 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, {'frequency': 'f', 'synset': 'train.n.01', 'synonyms': 
['train_(railroad_vehicle)', 'railroad_train'], 'id': 1115, 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'id': 1116, 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'synset': 'tray.n.01', 'synonyms': ['tray'], 'id': 1117, 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'id': 1118, 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'id': 1119, 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'c', 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'id': 1120, 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'f', 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'id': 1121, 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'id': 1122, 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'synset': 'truck.n.01', 'synonyms': ['truck'], 'id': 1123, 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'synset': 'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'id': 1124, 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'id': 1125, 'def': 'luggage consisting 
of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'synset': 'tub.n.02', 'synonyms': ['vat'], 'id': 1126, 'def': 'a large vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'synset': 'turban.n.01', 'synonyms': ['turban'], 'id': 1127, 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'c', 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'id': 1128, 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'id': 1129, 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'id': 1130, 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'c', 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'id': 1131, 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'c', 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'id': 1132, 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'id': 1133, 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, {'frequency': 'f', 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'id': 1134, 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'id': 1135, 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'f', 'synset': 'urinal.n.01', 'synonyms': 
['urinal'], 'id': 1136, 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'c', 'synset': 'urn.n.01', 'synonyms': ['urn'], 'id': 1137, 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'id': 1138, 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'f', 'synset': 'vase.n.01', 'synonyms': ['vase'], 'id': 1139, 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'id': 1140, 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'id': 1141, 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'f', 'synset': 'vest.n.01', 'synonyms': ['vest', 'waistcoat'], 'id': 1142, 'def': "a man's sleeveless garment worn underneath a coat", 'name': 'vest'}, {'frequency': 'c', 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'id': 1143, 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'id': 1144, 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, {'frequency': 'r', 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'id': 1145, 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'id': 1146, 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'c', 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'id': 1147, 'def': 'an inflated ball used in playing 
volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'id': 1148, 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'id': 1149, 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'id': 1150, 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'id': 1151, 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'id': 1152, 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'id': 1153, 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'id': 1154, 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'id': 1155, 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, {'frequency': 'f', 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'id': 1156, 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'id': 1157, 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'id': 1158, 'def': 'a tall 
piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'synset': 'washbasin.n.01', 'synonyms': ['washbasin', 'basin_(for_washing)', 'washbowl', 'washstand', 'handbasin'], 'id': 1159, 'def': 'a bathroom sink that is permanently installed and connected to a water supply and drainpipe; where you can wash your hands and face', 'name': 'washbasin'}, {'frequency': 'c', 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'id': 1160, 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'id': 1161, 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'id': 1162, 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'id': 1163, 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'id': 1164, 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'synset': 'water_heater.n.01', 'synonyms': ['water_heater', 'hot-water_heater'], 'id': 1165, 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'c', 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'id': 1166, 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'id': 1167, 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'id': 1168, 'def': 
'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'id': 1169, 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'id': 1170, 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'id': 1171, 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'f', 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'id': 1172, 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'id': 1173, 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'id': 1174, 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'synset': 'wedding_cake.n.01', 'synonyms': ['wedding_cake', 'bridecake'], 'id': 1175, 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'id': 1176, 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'id': 1177, 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 
'wet_suit'}, {'frequency': 'f', 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'id': 1178, 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'id': 1179, 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'id': 1180, 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'c', 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'id': 1181, 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'c', 'synset': 'wig.n.01', 'synonyms': ['wig'], 'id': 1182, 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'id': 1183, 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'id': 1184, 'def': 'A mill or turbine that is powered by wind', 'name': 'windmill'}, {'frequency': 'c', 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'id': 1185, 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'id': 1186, 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'id': 1187, 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'synset': 
'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'id': 1188, 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'c', 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'id': 1189, 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'id': 1190, 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'f', 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'id': 1191, 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'synset': 'wok.n.01', 'synonyms': ['wok'], 'id': 1192, 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'id': 1193, 'def': 'a wild carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'id': 1194, 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'id': 1195, 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'id': 1196, 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'f', 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'id': 1197, 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'id': 1198, 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'c', 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'id': 1199, 'def': 'an expensive vessel propelled by sail or power and 
used for cruising or racing', 'name': 'yacht'}, {'frequency': 'c', 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'id': 1200, 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'c', 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'id': 1201, 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'id': 1202, 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'id': 1203, 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa
+# fmt: on
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v1_category_image_count.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v1_category_image_count.py
new file mode 100644
index 0000000000000000000000000000000000000000..31bf0cfcd5096ab87835db86a28671d474514c40
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/lvis_v1_category_image_count.py
@@ -0,0 +1,20 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Autogen with
+# with open("lvis_v1_train.json", "r") as f:
+# a = json.load(f)
+# c = a["categories"]
+# for x in c:
+# del x["name"]
+# del x["instance_count"]
+# del x["def"]
+# del x["synonyms"]
+# del x["frequency"]
+# del x["synset"]
+# LVIS_CATEGORY_IMAGE_COUNT = repr(c) + " # noqa"
+# with open("/tmp/lvis_category_image_count.py", "wt") as f:
+# f.write(f"LVIS_CATEGORY_IMAGE_COUNT = {LVIS_CATEGORY_IMAGE_COUNT}")
+# Then paste the contents of that file below
+
+# fmt: off
+LVIS_CATEGORY_IMAGE_COUNT = [{'id': 1, 'image_count': 64}, {'id': 2, 'image_count': 364}, {'id': 3, 'image_count': 1911}, {'id': 4, 'image_count': 149}, {'id': 5, 'image_count': 29}, {'id': 6, 'image_count': 26}, {'id': 7, 'image_count': 59}, {'id': 8, 'image_count': 22}, {'id': 9, 'image_count': 12}, {'id': 10, 'image_count': 28}, {'id': 11, 'image_count': 505}, {'id': 12, 'image_count': 1207}, {'id': 13, 'image_count': 4}, {'id': 14, 'image_count': 10}, {'id': 15, 'image_count': 500}, {'id': 16, 'image_count': 33}, {'id': 17, 'image_count': 3}, {'id': 18, 'image_count': 44}, {'id': 19, 'image_count': 561}, {'id': 20, 'image_count': 8}, {'id': 21, 'image_count': 9}, {'id': 22, 'image_count': 33}, {'id': 23, 'image_count': 1883}, {'id': 24, 'image_count': 98}, {'id': 25, 'image_count': 70}, {'id': 26, 'image_count': 46}, {'id': 27, 'image_count': 117}, {'id': 28, 'image_count': 41}, {'id': 29, 'image_count': 1395}, {'id': 30, 'image_count': 7}, {'id': 31, 'image_count': 1}, {'id': 32, 'image_count': 314}, {'id': 33, 'image_count': 31}, {'id': 34, 'image_count': 1905}, {'id': 35, 'image_count': 1859}, {'id': 36, 'image_count': 1623}, {'id': 37, 'image_count': 47}, {'id': 38, 'image_count': 3}, {'id': 39, 'image_count': 3}, {'id': 40, 'image_count': 1}, {'id': 41, 'image_count': 305}, {'id': 42, 'image_count': 6}, {'id': 43, 'image_count': 210}, {'id': 44, 'image_count': 36}, {'id': 45, 'image_count': 1787}, {'id': 46, 'image_count': 17}, {'id': 47, 'image_count': 51}, {'id': 48, 'image_count': 138}, {'id': 49, 'image_count': 3}, {'id': 50, 'image_count': 1470}, {'id': 51, 'image_count': 3}, {'id': 52, 'image_count': 2}, {'id': 53, 'image_count': 186}, {'id': 54, 'image_count': 76}, {'id': 55, 'image_count': 26}, {'id': 56, 'image_count': 303}, {'id': 57, 'image_count': 738}, {'id': 58, 'image_count': 1799}, {'id': 59, 'image_count': 1934}, {'id': 60, 'image_count': 1609}, {'id': 61, 'image_count': 1622}, {'id': 62, 'image_count': 41}, {'id': 63, 'image_count': 4}, 
{'id': 64, 'image_count': 11}, {'id': 65, 'image_count': 270}, {'id': 66, 'image_count': 349}, {'id': 67, 'image_count': 42}, {'id': 68, 'image_count': 823}, {'id': 69, 'image_count': 6}, {'id': 70, 'image_count': 48}, {'id': 71, 'image_count': 3}, {'id': 72, 'image_count': 42}, {'id': 73, 'image_count': 24}, {'id': 74, 'image_count': 16}, {'id': 75, 'image_count': 605}, {'id': 76, 'image_count': 646}, {'id': 77, 'image_count': 1765}, {'id': 78, 'image_count': 2}, {'id': 79, 'image_count': 125}, {'id': 80, 'image_count': 1420}, {'id': 81, 'image_count': 140}, {'id': 82, 'image_count': 4}, {'id': 83, 'image_count': 322}, {'id': 84, 'image_count': 60}, {'id': 85, 'image_count': 2}, {'id': 86, 'image_count': 231}, {'id': 87, 'image_count': 333}, {'id': 88, 'image_count': 1941}, {'id': 89, 'image_count': 367}, {'id': 90, 'image_count': 1922}, {'id': 91, 'image_count': 18}, {'id': 92, 'image_count': 81}, {'id': 93, 'image_count': 1}, {'id': 94, 'image_count': 1852}, {'id': 95, 'image_count': 430}, {'id': 96, 'image_count': 247}, {'id': 97, 'image_count': 94}, {'id': 98, 'image_count': 21}, {'id': 99, 'image_count': 1821}, {'id': 100, 'image_count': 16}, {'id': 101, 'image_count': 12}, {'id': 102, 'image_count': 25}, {'id': 103, 'image_count': 41}, {'id': 104, 'image_count': 244}, {'id': 105, 'image_count': 7}, {'id': 106, 'image_count': 1}, {'id': 107, 'image_count': 40}, {'id': 108, 'image_count': 40}, {'id': 109, 'image_count': 104}, {'id': 110, 'image_count': 1671}, {'id': 111, 'image_count': 49}, {'id': 112, 'image_count': 243}, {'id': 113, 'image_count': 2}, {'id': 114, 'image_count': 242}, {'id': 115, 'image_count': 271}, {'id': 116, 'image_count': 104}, {'id': 117, 'image_count': 8}, {'id': 118, 'image_count': 1758}, {'id': 119, 'image_count': 1}, {'id': 120, 'image_count': 48}, {'id': 121, 'image_count': 14}, {'id': 122, 'image_count': 40}, {'id': 123, 'image_count': 1}, {'id': 124, 'image_count': 37}, {'id': 125, 'image_count': 1510}, {'id': 126, 'image_count': 
6}, {'id': 127, 'image_count': 1903}, {'id': 128, 'image_count': 70}, {'id': 129, 'image_count': 86}, {'id': 130, 'image_count': 7}, {'id': 131, 'image_count': 5}, {'id': 132, 'image_count': 1406}, {'id': 133, 'image_count': 1901}, {'id': 134, 'image_count': 15}, {'id': 135, 'image_count': 28}, {'id': 136, 'image_count': 6}, {'id': 137, 'image_count': 494}, {'id': 138, 'image_count': 234}, {'id': 139, 'image_count': 1922}, {'id': 140, 'image_count': 1}, {'id': 141, 'image_count': 35}, {'id': 142, 'image_count': 5}, {'id': 143, 'image_count': 1828}, {'id': 144, 'image_count': 8}, {'id': 145, 'image_count': 63}, {'id': 146, 'image_count': 1668}, {'id': 147, 'image_count': 4}, {'id': 148, 'image_count': 95}, {'id': 149, 'image_count': 17}, {'id': 150, 'image_count': 1567}, {'id': 151, 'image_count': 2}, {'id': 152, 'image_count': 103}, {'id': 153, 'image_count': 50}, {'id': 154, 'image_count': 1309}, {'id': 155, 'image_count': 6}, {'id': 156, 'image_count': 92}, {'id': 157, 'image_count': 19}, {'id': 158, 'image_count': 37}, {'id': 159, 'image_count': 4}, {'id': 160, 'image_count': 709}, {'id': 161, 'image_count': 9}, {'id': 162, 'image_count': 82}, {'id': 163, 'image_count': 15}, {'id': 164, 'image_count': 3}, {'id': 165, 'image_count': 61}, {'id': 166, 'image_count': 51}, {'id': 167, 'image_count': 5}, {'id': 168, 'image_count': 13}, {'id': 169, 'image_count': 642}, {'id': 170, 'image_count': 24}, {'id': 171, 'image_count': 255}, {'id': 172, 'image_count': 9}, {'id': 173, 'image_count': 1808}, {'id': 174, 'image_count': 31}, {'id': 175, 'image_count': 158}, {'id': 176, 'image_count': 80}, {'id': 177, 'image_count': 1884}, {'id': 178, 'image_count': 158}, {'id': 179, 'image_count': 2}, {'id': 180, 'image_count': 12}, {'id': 181, 'image_count': 1659}, {'id': 182, 'image_count': 7}, {'id': 183, 'image_count': 834}, {'id': 184, 'image_count': 57}, {'id': 185, 'image_count': 174}, {'id': 186, 'image_count': 95}, {'id': 187, 'image_count': 27}, {'id': 188, 'image_count': 
22}, {'id': 189, 'image_count': 1391}, {'id': 190, 'image_count': 90}, {'id': 191, 'image_count': 40}, {'id': 192, 'image_count': 445}, {'id': 193, 'image_count': 21}, {'id': 194, 'image_count': 1132}, {'id': 195, 'image_count': 177}, {'id': 196, 'image_count': 4}, {'id': 197, 'image_count': 17}, {'id': 198, 'image_count': 84}, {'id': 199, 'image_count': 55}, {'id': 200, 'image_count': 30}, {'id': 201, 'image_count': 25}, {'id': 202, 'image_count': 2}, {'id': 203, 'image_count': 125}, {'id': 204, 'image_count': 1135}, {'id': 205, 'image_count': 19}, {'id': 206, 'image_count': 72}, {'id': 207, 'image_count': 1926}, {'id': 208, 'image_count': 159}, {'id': 209, 'image_count': 7}, {'id': 210, 'image_count': 1}, {'id': 211, 'image_count': 13}, {'id': 212, 'image_count': 35}, {'id': 213, 'image_count': 18}, {'id': 214, 'image_count': 8}, {'id': 215, 'image_count': 6}, {'id': 216, 'image_count': 35}, {'id': 217, 'image_count': 1222}, {'id': 218, 'image_count': 103}, {'id': 219, 'image_count': 28}, {'id': 220, 'image_count': 63}, {'id': 221, 'image_count': 28}, {'id': 222, 'image_count': 5}, {'id': 223, 'image_count': 7}, {'id': 224, 'image_count': 14}, {'id': 225, 'image_count': 1918}, {'id': 226, 'image_count': 133}, {'id': 227, 'image_count': 16}, {'id': 228, 'image_count': 27}, {'id': 229, 'image_count': 110}, {'id': 230, 'image_count': 1895}, {'id': 231, 'image_count': 4}, {'id': 232, 'image_count': 1927}, {'id': 233, 'image_count': 8}, {'id': 234, 'image_count': 1}, {'id': 235, 'image_count': 263}, {'id': 236, 'image_count': 10}, {'id': 237, 'image_count': 2}, {'id': 238, 'image_count': 3}, {'id': 239, 'image_count': 87}, {'id': 240, 'image_count': 9}, {'id': 241, 'image_count': 71}, {'id': 242, 'image_count': 13}, {'id': 243, 'image_count': 18}, {'id': 244, 'image_count': 2}, {'id': 245, 'image_count': 5}, {'id': 246, 'image_count': 45}, {'id': 247, 'image_count': 1}, {'id': 248, 'image_count': 23}, {'id': 249, 'image_count': 32}, {'id': 250, 'image_count': 4}, 
{'id': 251, 'image_count': 1}, {'id': 252, 'image_count': 858}, {'id': 253, 'image_count': 661}, {'id': 254, 'image_count': 168}, {'id': 255, 'image_count': 210}, {'id': 256, 'image_count': 65}, {'id': 257, 'image_count': 4}, {'id': 258, 'image_count': 2}, {'id': 259, 'image_count': 159}, {'id': 260, 'image_count': 31}, {'id': 261, 'image_count': 811}, {'id': 262, 'image_count': 1}, {'id': 263, 'image_count': 42}, {'id': 264, 'image_count': 27}, {'id': 265, 'image_count': 2}, {'id': 266, 'image_count': 5}, {'id': 267, 'image_count': 95}, {'id': 268, 'image_count': 32}, {'id': 269, 'image_count': 1}, {'id': 270, 'image_count': 1}, {'id': 271, 'image_count': 1844}, {'id': 272, 'image_count': 897}, {'id': 273, 'image_count': 31}, {'id': 274, 'image_count': 23}, {'id': 275, 'image_count': 1}, {'id': 276, 'image_count': 202}, {'id': 277, 'image_count': 746}, {'id': 278, 'image_count': 44}, {'id': 279, 'image_count': 14}, {'id': 280, 'image_count': 26}, {'id': 281, 'image_count': 1}, {'id': 282, 'image_count': 2}, {'id': 283, 'image_count': 25}, {'id': 284, 'image_count': 238}, {'id': 285, 'image_count': 592}, {'id': 286, 'image_count': 26}, {'id': 287, 'image_count': 5}, {'id': 288, 'image_count': 42}, {'id': 289, 'image_count': 13}, {'id': 290, 'image_count': 46}, {'id': 291, 'image_count': 1}, {'id': 292, 'image_count': 8}, {'id': 293, 'image_count': 34}, {'id': 294, 'image_count': 5}, {'id': 295, 'image_count': 1}, {'id': 296, 'image_count': 1871}, {'id': 297, 'image_count': 717}, {'id': 298, 'image_count': 1010}, {'id': 299, 'image_count': 679}, {'id': 300, 'image_count': 3}, {'id': 301, 'image_count': 4}, {'id': 302, 'image_count': 1}, {'id': 303, 'image_count': 166}, {'id': 304, 'image_count': 2}, {'id': 305, 'image_count': 266}, {'id': 306, 'image_count': 101}, {'id': 307, 'image_count': 6}, {'id': 308, 'image_count': 14}, {'id': 309, 'image_count': 133}, {'id': 310, 'image_count': 2}, {'id': 311, 'image_count': 38}, {'id': 312, 'image_count': 95}, {'id': 313, 
'image_count': 1}, {'id': 314, 'image_count': 12}, {'id': 315, 'image_count': 49}, {'id': 316, 'image_count': 5}, {'id': 317, 'image_count': 5}, {'id': 318, 'image_count': 16}, {'id': 319, 'image_count': 216}, {'id': 320, 'image_count': 12}, {'id': 321, 'image_count': 1}, {'id': 322, 'image_count': 54}, {'id': 323, 'image_count': 5}, {'id': 324, 'image_count': 245}, {'id': 325, 'image_count': 12}, {'id': 326, 'image_count': 7}, {'id': 327, 'image_count': 35}, {'id': 328, 'image_count': 36}, {'id': 329, 'image_count': 32}, {'id': 330, 'image_count': 1027}, {'id': 331, 'image_count': 10}, {'id': 332, 'image_count': 12}, {'id': 333, 'image_count': 1}, {'id': 334, 'image_count': 67}, {'id': 335, 'image_count': 71}, {'id': 336, 'image_count': 30}, {'id': 337, 'image_count': 48}, {'id': 338, 'image_count': 249}, {'id': 339, 'image_count': 13}, {'id': 340, 'image_count': 29}, {'id': 341, 'image_count': 14}, {'id': 342, 'image_count': 236}, {'id': 343, 'image_count': 15}, {'id': 344, 'image_count': 1521}, {'id': 345, 'image_count': 25}, {'id': 346, 'image_count': 249}, {'id': 347, 'image_count': 139}, {'id': 348, 'image_count': 2}, {'id': 349, 'image_count': 2}, {'id': 350, 'image_count': 1890}, {'id': 351, 'image_count': 1240}, {'id': 352, 'image_count': 1}, {'id': 353, 'image_count': 9}, {'id': 354, 'image_count': 1}, {'id': 355, 'image_count': 3}, {'id': 356, 'image_count': 11}, {'id': 357, 'image_count': 4}, {'id': 358, 'image_count': 236}, {'id': 359, 'image_count': 44}, {'id': 360, 'image_count': 19}, {'id': 361, 'image_count': 1100}, {'id': 362, 'image_count': 7}, {'id': 363, 'image_count': 69}, {'id': 364, 'image_count': 2}, {'id': 365, 'image_count': 8}, {'id': 366, 'image_count': 5}, {'id': 367, 'image_count': 227}, {'id': 368, 'image_count': 6}, {'id': 369, 'image_count': 106}, {'id': 370, 'image_count': 81}, {'id': 371, 'image_count': 17}, {'id': 372, 'image_count': 134}, {'id': 373, 'image_count': 312}, {'id': 374, 'image_count': 8}, {'id': 375, 'image_count': 
271}, {'id': 376, 'image_count': 2}, {'id': 377, 'image_count': 103}, {'id': 378, 'image_count': 1938}, {'id': 379, 'image_count': 574}, {'id': 380, 'image_count': 120}, {'id': 381, 'image_count': 2}, {'id': 382, 'image_count': 2}, {'id': 383, 'image_count': 13}, {'id': 384, 'image_count': 29}, {'id': 385, 'image_count': 1710}, {'id': 386, 'image_count': 66}, {'id': 387, 'image_count': 1008}, {'id': 388, 'image_count': 1}, {'id': 389, 'image_count': 3}, {'id': 390, 'image_count': 1942}, {'id': 391, 'image_count': 19}, {'id': 392, 'image_count': 1488}, {'id': 393, 'image_count': 46}, {'id': 394, 'image_count': 106}, {'id': 395, 'image_count': 115}, {'id': 396, 'image_count': 19}, {'id': 397, 'image_count': 2}, {'id': 398, 'image_count': 1}, {'id': 399, 'image_count': 28}, {'id': 400, 'image_count': 9}, {'id': 401, 'image_count': 192}, {'id': 402, 'image_count': 12}, {'id': 403, 'image_count': 21}, {'id': 404, 'image_count': 247}, {'id': 405, 'image_count': 6}, {'id': 406, 'image_count': 64}, {'id': 407, 'image_count': 7}, {'id': 408, 'image_count': 40}, {'id': 409, 'image_count': 542}, {'id': 410, 'image_count': 2}, {'id': 411, 'image_count': 1898}, {'id': 412, 'image_count': 36}, {'id': 413, 'image_count': 4}, {'id': 414, 'image_count': 1}, {'id': 415, 'image_count': 191}, {'id': 416, 'image_count': 6}, {'id': 417, 'image_count': 41}, {'id': 418, 'image_count': 39}, {'id': 419, 'image_count': 46}, {'id': 420, 'image_count': 1}, {'id': 421, 'image_count': 1451}, {'id': 422, 'image_count': 1878}, {'id': 423, 'image_count': 11}, {'id': 424, 'image_count': 82}, {'id': 425, 'image_count': 18}, {'id': 426, 'image_count': 1}, {'id': 427, 'image_count': 7}, {'id': 428, 'image_count': 3}, {'id': 429, 'image_count': 575}, {'id': 430, 'image_count': 1907}, {'id': 431, 'image_count': 8}, {'id': 432, 'image_count': 4}, {'id': 433, 'image_count': 32}, {'id': 434, 'image_count': 11}, {'id': 435, 'image_count': 4}, {'id': 436, 'image_count': 54}, {'id': 437, 'image_count': 202}, 
{'id': 438, 'image_count': 32}, {'id': 439, 'image_count': 3}, {'id': 440, 'image_count': 130}, {'id': 441, 'image_count': 119}, {'id': 442, 'image_count': 141}, {'id': 443, 'image_count': 29}, {'id': 444, 'image_count': 525}, {'id': 445, 'image_count': 1323}, {'id': 446, 'image_count': 2}, {'id': 447, 'image_count': 113}, {'id': 448, 'image_count': 16}, {'id': 449, 'image_count': 7}, {'id': 450, 'image_count': 35}, {'id': 451, 'image_count': 1908}, {'id': 452, 'image_count': 353}, {'id': 453, 'image_count': 18}, {'id': 454, 'image_count': 14}, {'id': 455, 'image_count': 77}, {'id': 456, 'image_count': 8}, {'id': 457, 'image_count': 37}, {'id': 458, 'image_count': 1}, {'id': 459, 'image_count': 346}, {'id': 460, 'image_count': 19}, {'id': 461, 'image_count': 1779}, {'id': 462, 'image_count': 23}, {'id': 463, 'image_count': 25}, {'id': 464, 'image_count': 67}, {'id': 465, 'image_count': 19}, {'id': 466, 'image_count': 28}, {'id': 467, 'image_count': 4}, {'id': 468, 'image_count': 27}, {'id': 469, 'image_count': 1861}, {'id': 470, 'image_count': 11}, {'id': 471, 'image_count': 13}, {'id': 472, 'image_count': 13}, {'id': 473, 'image_count': 32}, {'id': 474, 'image_count': 1767}, {'id': 475, 'image_count': 42}, {'id': 476, 'image_count': 17}, {'id': 477, 'image_count': 128}, {'id': 478, 'image_count': 1}, {'id': 479, 'image_count': 9}, {'id': 480, 'image_count': 10}, {'id': 481, 'image_count': 4}, {'id': 482, 'image_count': 9}, {'id': 483, 'image_count': 18}, {'id': 484, 'image_count': 41}, {'id': 485, 'image_count': 28}, {'id': 486, 'image_count': 3}, {'id': 487, 'image_count': 65}, {'id': 488, 'image_count': 9}, {'id': 489, 'image_count': 23}, {'id': 490, 'image_count': 24}, {'id': 491, 'image_count': 1}, {'id': 492, 'image_count': 2}, {'id': 493, 'image_count': 59}, {'id': 494, 'image_count': 48}, {'id': 495, 'image_count': 17}, {'id': 496, 'image_count': 1877}, {'id': 497, 'image_count': 18}, {'id': 498, 'image_count': 1920}, {'id': 499, 'image_count': 50}, {'id': 
500, 'image_count': 1890}, {'id': 501, 'image_count': 99}, {'id': 502, 'image_count': 1530}, {'id': 503, 'image_count': 3}, {'id': 504, 'image_count': 11}, {'id': 505, 'image_count': 19}, {'id': 506, 'image_count': 3}, {'id': 507, 'image_count': 63}, {'id': 508, 'image_count': 5}, {'id': 509, 'image_count': 6}, {'id': 510, 'image_count': 233}, {'id': 511, 'image_count': 54}, {'id': 512, 'image_count': 36}, {'id': 513, 'image_count': 10}, {'id': 514, 'image_count': 124}, {'id': 515, 'image_count': 101}, {'id': 516, 'image_count': 3}, {'id': 517, 'image_count': 363}, {'id': 518, 'image_count': 3}, {'id': 519, 'image_count': 30}, {'id': 520, 'image_count': 18}, {'id': 521, 'image_count': 199}, {'id': 522, 'image_count': 97}, {'id': 523, 'image_count': 32}, {'id': 524, 'image_count': 121}, {'id': 525, 'image_count': 16}, {'id': 526, 'image_count': 12}, {'id': 527, 'image_count': 2}, {'id': 528, 'image_count': 214}, {'id': 529, 'image_count': 48}, {'id': 530, 'image_count': 26}, {'id': 531, 'image_count': 13}, {'id': 532, 'image_count': 4}, {'id': 533, 'image_count': 11}, {'id': 534, 'image_count': 123}, {'id': 535, 'image_count': 7}, {'id': 536, 'image_count': 200}, {'id': 537, 'image_count': 91}, {'id': 538, 'image_count': 9}, {'id': 539, 'image_count': 72}, {'id': 540, 'image_count': 1886}, {'id': 541, 'image_count': 4}, {'id': 542, 'image_count': 1}, {'id': 543, 'image_count': 1}, {'id': 544, 'image_count': 1932}, {'id': 545, 'image_count': 4}, {'id': 546, 'image_count': 56}, {'id': 547, 'image_count': 854}, {'id': 548, 'image_count': 755}, {'id': 549, 'image_count': 1843}, {'id': 550, 'image_count': 96}, {'id': 551, 'image_count': 7}, {'id': 552, 'image_count': 74}, {'id': 553, 'image_count': 66}, {'id': 554, 'image_count': 57}, {'id': 555, 'image_count': 44}, {'id': 556, 'image_count': 1905}, {'id': 557, 'image_count': 4}, {'id': 558, 'image_count': 90}, {'id': 559, 'image_count': 1635}, {'id': 560, 'image_count': 8}, {'id': 561, 'image_count': 5}, {'id': 562, 
'image_count': 50}, {'id': 563, 'image_count': 545}, {'id': 564, 'image_count': 20}, {'id': 565, 'image_count': 193}, {'id': 566, 'image_count': 285}, {'id': 567, 'image_count': 3}, {'id': 568, 'image_count': 1}, {'id': 569, 'image_count': 1904}, {'id': 570, 'image_count': 294}, {'id': 571, 'image_count': 3}, {'id': 572, 'image_count': 5}, {'id': 573, 'image_count': 24}, {'id': 574, 'image_count': 2}, {'id': 575, 'image_count': 2}, {'id': 576, 'image_count': 16}, {'id': 577, 'image_count': 8}, {'id': 578, 'image_count': 154}, {'id': 579, 'image_count': 66}, {'id': 580, 'image_count': 1}, {'id': 581, 'image_count': 24}, {'id': 582, 'image_count': 1}, {'id': 583, 'image_count': 4}, {'id': 584, 'image_count': 75}, {'id': 585, 'image_count': 6}, {'id': 586, 'image_count': 126}, {'id': 587, 'image_count': 24}, {'id': 588, 'image_count': 22}, {'id': 589, 'image_count': 1872}, {'id': 590, 'image_count': 16}, {'id': 591, 'image_count': 423}, {'id': 592, 'image_count': 1927}, {'id': 593, 'image_count': 38}, {'id': 594, 'image_count': 3}, {'id': 595, 'image_count': 1945}, {'id': 596, 'image_count': 35}, {'id': 597, 'image_count': 1}, {'id': 598, 'image_count': 13}, {'id': 599, 'image_count': 9}, {'id': 600, 'image_count': 14}, {'id': 601, 'image_count': 37}, {'id': 602, 'image_count': 3}, {'id': 603, 'image_count': 4}, {'id': 604, 'image_count': 100}, {'id': 605, 'image_count': 195}, {'id': 606, 'image_count': 1}, {'id': 607, 'image_count': 12}, {'id': 608, 'image_count': 24}, {'id': 609, 'image_count': 489}, {'id': 610, 'image_count': 10}, {'id': 611, 'image_count': 1689}, {'id': 612, 'image_count': 42}, {'id': 613, 'image_count': 81}, {'id': 614, 'image_count': 894}, {'id': 615, 'image_count': 1868}, {'id': 616, 'image_count': 7}, {'id': 617, 'image_count': 1567}, {'id': 618, 'image_count': 10}, {'id': 619, 'image_count': 8}, {'id': 620, 'image_count': 7}, {'id': 621, 'image_count': 629}, {'id': 622, 'image_count': 89}, {'id': 623, 'image_count': 15}, {'id': 624, 
'image_count': 134}, {'id': 625, 'image_count': 4}, {'id': 626, 'image_count': 1802}, {'id': 627, 'image_count': 595}, {'id': 628, 'image_count': 1210}, {'id': 629, 'image_count': 48}, {'id': 630, 'image_count': 418}, {'id': 631, 'image_count': 1846}, {'id': 632, 'image_count': 5}, {'id': 633, 'image_count': 221}, {'id': 634, 'image_count': 10}, {'id': 635, 'image_count': 7}, {'id': 636, 'image_count': 76}, {'id': 637, 'image_count': 22}, {'id': 638, 'image_count': 10}, {'id': 639, 'image_count': 341}, {'id': 640, 'image_count': 1}, {'id': 641, 'image_count': 705}, {'id': 642, 'image_count': 1900}, {'id': 643, 'image_count': 188}, {'id': 644, 'image_count': 227}, {'id': 645, 'image_count': 861}, {'id': 646, 'image_count': 6}, {'id': 647, 'image_count': 115}, {'id': 648, 'image_count': 5}, {'id': 649, 'image_count': 43}, {'id': 650, 'image_count': 14}, {'id': 651, 'image_count': 6}, {'id': 652, 'image_count': 15}, {'id': 653, 'image_count': 1167}, {'id': 654, 'image_count': 15}, {'id': 655, 'image_count': 994}, {'id': 656, 'image_count': 28}, {'id': 657, 'image_count': 2}, {'id': 658, 'image_count': 338}, {'id': 659, 'image_count': 334}, {'id': 660, 'image_count': 15}, {'id': 661, 'image_count': 102}, {'id': 662, 'image_count': 1}, {'id': 663, 'image_count': 8}, {'id': 664, 'image_count': 1}, {'id': 665, 'image_count': 1}, {'id': 666, 'image_count': 28}, {'id': 667, 'image_count': 91}, {'id': 668, 'image_count': 260}, {'id': 669, 'image_count': 131}, {'id': 670, 'image_count': 128}, {'id': 671, 'image_count': 3}, {'id': 672, 'image_count': 10}, {'id': 673, 'image_count': 39}, {'id': 674, 'image_count': 2}, {'id': 675, 'image_count': 925}, {'id': 676, 'image_count': 354}, {'id': 677, 'image_count': 31}, {'id': 678, 'image_count': 10}, {'id': 679, 'image_count': 215}, {'id': 680, 'image_count': 71}, {'id': 681, 'image_count': 43}, {'id': 682, 'image_count': 28}, {'id': 683, 'image_count': 34}, {'id': 684, 'image_count': 16}, {'id': 685, 'image_count': 273}, {'id': 
686, 'image_count': 2}, {'id': 687, 'image_count': 999}, {'id': 688, 'image_count': 4}, {'id': 689, 'image_count': 107}, {'id': 690, 'image_count': 2}, {'id': 691, 'image_count': 1}, {'id': 692, 'image_count': 454}, {'id': 693, 'image_count': 9}, {'id': 694, 'image_count': 1901}, {'id': 695, 'image_count': 61}, {'id': 696, 'image_count': 91}, {'id': 697, 'image_count': 46}, {'id': 698, 'image_count': 1402}, {'id': 699, 'image_count': 74}, {'id': 700, 'image_count': 421}, {'id': 701, 'image_count': 226}, {'id': 702, 'image_count': 10}, {'id': 703, 'image_count': 1720}, {'id': 704, 'image_count': 261}, {'id': 705, 'image_count': 1337}, {'id': 706, 'image_count': 293}, {'id': 707, 'image_count': 62}, {'id': 708, 'image_count': 814}, {'id': 709, 'image_count': 407}, {'id': 710, 'image_count': 6}, {'id': 711, 'image_count': 16}, {'id': 712, 'image_count': 7}, {'id': 713, 'image_count': 1791}, {'id': 714, 'image_count': 2}, {'id': 715, 'image_count': 1915}, {'id': 716, 'image_count': 1940}, {'id': 717, 'image_count': 13}, {'id': 718, 'image_count': 16}, {'id': 719, 'image_count': 448}, {'id': 720, 'image_count': 12}, {'id': 721, 'image_count': 18}, {'id': 722, 'image_count': 4}, {'id': 723, 'image_count': 71}, {'id': 724, 'image_count': 189}, {'id': 725, 'image_count': 74}, {'id': 726, 'image_count': 103}, {'id': 727, 'image_count': 3}, {'id': 728, 'image_count': 110}, {'id': 729, 'image_count': 5}, {'id': 730, 'image_count': 9}, {'id': 731, 'image_count': 15}, {'id': 732, 'image_count': 25}, {'id': 733, 'image_count': 7}, {'id': 734, 'image_count': 647}, {'id': 735, 'image_count': 824}, {'id': 736, 'image_count': 100}, {'id': 737, 'image_count': 47}, {'id': 738, 'image_count': 121}, {'id': 739, 'image_count': 731}, {'id': 740, 'image_count': 73}, {'id': 741, 'image_count': 49}, {'id': 742, 'image_count': 23}, {'id': 743, 'image_count': 4}, {'id': 744, 'image_count': 62}, {'id': 745, 'image_count': 118}, {'id': 746, 'image_count': 99}, {'id': 747, 'image_count': 40}, 
{'id': 748, 'image_count': 1036}, {'id': 749, 'image_count': 105}, {'id': 750, 'image_count': 21}, {'id': 751, 'image_count': 229}, {'id': 752, 'image_count': 7}, {'id': 753, 'image_count': 72}, {'id': 754, 'image_count': 9}, {'id': 755, 'image_count': 10}, {'id': 756, 'image_count': 328}, {'id': 757, 'image_count': 468}, {'id': 758, 'image_count': 1}, {'id': 759, 'image_count': 2}, {'id': 760, 'image_count': 24}, {'id': 761, 'image_count': 11}, {'id': 762, 'image_count': 72}, {'id': 763, 'image_count': 17}, {'id': 764, 'image_count': 10}, {'id': 765, 'image_count': 17}, {'id': 766, 'image_count': 489}, {'id': 767, 'image_count': 47}, {'id': 768, 'image_count': 93}, {'id': 769, 'image_count': 1}, {'id': 770, 'image_count': 12}, {'id': 771, 'image_count': 228}, {'id': 772, 'image_count': 5}, {'id': 773, 'image_count': 76}, {'id': 774, 'image_count': 71}, {'id': 775, 'image_count': 30}, {'id': 776, 'image_count': 109}, {'id': 777, 'image_count': 14}, {'id': 778, 'image_count': 1}, {'id': 779, 'image_count': 8}, {'id': 780, 'image_count': 26}, {'id': 781, 'image_count': 339}, {'id': 782, 'image_count': 153}, {'id': 783, 'image_count': 2}, {'id': 784, 'image_count': 3}, {'id': 785, 'image_count': 8}, {'id': 786, 'image_count': 47}, {'id': 787, 'image_count': 8}, {'id': 788, 'image_count': 6}, {'id': 789, 'image_count': 116}, {'id': 790, 'image_count': 69}, {'id': 791, 'image_count': 13}, {'id': 792, 'image_count': 6}, {'id': 793, 'image_count': 1928}, {'id': 794, 'image_count': 79}, {'id': 795, 'image_count': 14}, {'id': 796, 'image_count': 7}, {'id': 797, 'image_count': 20}, {'id': 798, 'image_count': 114}, {'id': 799, 'image_count': 221}, {'id': 800, 'image_count': 502}, {'id': 801, 'image_count': 62}, {'id': 802, 'image_count': 87}, {'id': 803, 'image_count': 4}, {'id': 804, 'image_count': 1912}, {'id': 805, 'image_count': 7}, {'id': 806, 'image_count': 186}, {'id': 807, 'image_count': 18}, {'id': 808, 'image_count': 4}, {'id': 809, 'image_count': 3}, {'id': 810, 
'image_count': 7}, {'id': 811, 'image_count': 1413}, {'id': 812, 'image_count': 7}, {'id': 813, 'image_count': 12}, {'id': 814, 'image_count': 248}, {'id': 815, 'image_count': 4}, {'id': 816, 'image_count': 1881}, {'id': 817, 'image_count': 529}, {'id': 818, 'image_count': 1932}, {'id': 819, 'image_count': 50}, {'id': 820, 'image_count': 3}, {'id': 821, 'image_count': 28}, {'id': 822, 'image_count': 10}, {'id': 823, 'image_count': 5}, {'id': 824, 'image_count': 5}, {'id': 825, 'image_count': 18}, {'id': 826, 'image_count': 14}, {'id': 827, 'image_count': 1890}, {'id': 828, 'image_count': 660}, {'id': 829, 'image_count': 8}, {'id': 830, 'image_count': 25}, {'id': 831, 'image_count': 10}, {'id': 832, 'image_count': 218}, {'id': 833, 'image_count': 36}, {'id': 834, 'image_count': 16}, {'id': 835, 'image_count': 808}, {'id': 836, 'image_count': 479}, {'id': 837, 'image_count': 1404}, {'id': 838, 'image_count': 307}, {'id': 839, 'image_count': 57}, {'id': 840, 'image_count': 28}, {'id': 841, 'image_count': 80}, {'id': 842, 'image_count': 11}, {'id': 843, 'image_count': 92}, {'id': 844, 'image_count': 20}, {'id': 845, 'image_count': 194}, {'id': 846, 'image_count': 23}, {'id': 847, 'image_count': 52}, {'id': 848, 'image_count': 673}, {'id': 849, 'image_count': 2}, {'id': 850, 'image_count': 2}, {'id': 851, 'image_count': 1}, {'id': 852, 'image_count': 2}, {'id': 853, 'image_count': 8}, {'id': 854, 'image_count': 80}, {'id': 855, 'image_count': 3}, {'id': 856, 'image_count': 3}, {'id': 857, 'image_count': 15}, {'id': 858, 'image_count': 2}, {'id': 859, 'image_count': 10}, {'id': 860, 'image_count': 386}, {'id': 861, 'image_count': 65}, {'id': 862, 'image_count': 3}, {'id': 863, 'image_count': 35}, {'id': 864, 'image_count': 5}, {'id': 865, 'image_count': 180}, {'id': 866, 'image_count': 99}, {'id': 867, 'image_count': 49}, {'id': 868, 'image_count': 28}, {'id': 869, 'image_count': 1}, {'id': 870, 'image_count': 52}, {'id': 871, 'image_count': 36}, {'id': 872, 
'image_count': 70}, {'id': 873, 'image_count': 6}, {'id': 874, 'image_count': 29}, {'id': 875, 'image_count': 24}, {'id': 876, 'image_count': 1115}, {'id': 877, 'image_count': 61}, {'id': 878, 'image_count': 18}, {'id': 879, 'image_count': 18}, {'id': 880, 'image_count': 665}, {'id': 881, 'image_count': 1096}, {'id': 882, 'image_count': 29}, {'id': 883, 'image_count': 8}, {'id': 884, 'image_count': 14}, {'id': 885, 'image_count': 1622}, {'id': 886, 'image_count': 2}, {'id': 887, 'image_count': 3}, {'id': 888, 'image_count': 32}, {'id': 889, 'image_count': 55}, {'id': 890, 'image_count': 1}, {'id': 891, 'image_count': 10}, {'id': 892, 'image_count': 10}, {'id': 893, 'image_count': 47}, {'id': 894, 'image_count': 3}, {'id': 895, 'image_count': 29}, {'id': 896, 'image_count': 342}, {'id': 897, 'image_count': 25}, {'id': 898, 'image_count': 1469}, {'id': 899, 'image_count': 521}, {'id': 900, 'image_count': 347}, {'id': 901, 'image_count': 35}, {'id': 902, 'image_count': 7}, {'id': 903, 'image_count': 207}, {'id': 904, 'image_count': 108}, {'id': 905, 'image_count': 2}, {'id': 906, 'image_count': 34}, {'id': 907, 'image_count': 12}, {'id': 908, 'image_count': 10}, {'id': 909, 'image_count': 13}, {'id': 910, 'image_count': 361}, {'id': 911, 'image_count': 1023}, {'id': 912, 'image_count': 782}, {'id': 913, 'image_count': 2}, {'id': 914, 'image_count': 5}, {'id': 915, 'image_count': 247}, {'id': 916, 'image_count': 221}, {'id': 917, 'image_count': 4}, {'id': 918, 'image_count': 8}, {'id': 919, 'image_count': 158}, {'id': 920, 'image_count': 3}, {'id': 921, 'image_count': 752}, {'id': 922, 'image_count': 64}, {'id': 923, 'image_count': 707}, {'id': 924, 'image_count': 143}, {'id': 925, 'image_count': 1}, {'id': 926, 'image_count': 49}, {'id': 927, 'image_count': 126}, {'id': 928, 'image_count': 76}, {'id': 929, 'image_count': 11}, {'id': 930, 'image_count': 11}, {'id': 931, 'image_count': 4}, {'id': 932, 'image_count': 39}, {'id': 933, 'image_count': 11}, {'id': 934, 
'image_count': 13}, {'id': 935, 'image_count': 91}, {'id': 936, 'image_count': 14}, {'id': 937, 'image_count': 5}, {'id': 938, 'image_count': 3}, {'id': 939, 'image_count': 10}, {'id': 940, 'image_count': 18}, {'id': 941, 'image_count': 9}, {'id': 942, 'image_count': 6}, {'id': 943, 'image_count': 951}, {'id': 944, 'image_count': 2}, {'id': 945, 'image_count': 1}, {'id': 946, 'image_count': 19}, {'id': 947, 'image_count': 1942}, {'id': 948, 'image_count': 1916}, {'id': 949, 'image_count': 139}, {'id': 950, 'image_count': 43}, {'id': 951, 'image_count': 1969}, {'id': 952, 'image_count': 5}, {'id': 953, 'image_count': 134}, {'id': 954, 'image_count': 74}, {'id': 955, 'image_count': 381}, {'id': 956, 'image_count': 1}, {'id': 957, 'image_count': 381}, {'id': 958, 'image_count': 6}, {'id': 959, 'image_count': 1826}, {'id': 960, 'image_count': 28}, {'id': 961, 'image_count': 1635}, {'id': 962, 'image_count': 1967}, {'id': 963, 'image_count': 16}, {'id': 964, 'image_count': 1926}, {'id': 965, 'image_count': 1789}, {'id': 966, 'image_count': 401}, {'id': 967, 'image_count': 1968}, {'id': 968, 'image_count': 1167}, {'id': 969, 'image_count': 1}, {'id': 970, 'image_count': 56}, {'id': 971, 'image_count': 17}, {'id': 972, 'image_count': 1}, {'id': 973, 'image_count': 58}, {'id': 974, 'image_count': 9}, {'id': 975, 'image_count': 8}, {'id': 976, 'image_count': 1124}, {'id': 977, 'image_count': 31}, {'id': 978, 'image_count': 16}, {'id': 979, 'image_count': 491}, {'id': 980, 'image_count': 432}, {'id': 981, 'image_count': 1945}, {'id': 982, 'image_count': 1899}, {'id': 983, 'image_count': 5}, {'id': 984, 'image_count': 28}, {'id': 985, 'image_count': 7}, {'id': 986, 'image_count': 146}, {'id': 987, 'image_count': 1}, {'id': 988, 'image_count': 25}, {'id': 989, 'image_count': 22}, {'id': 990, 'image_count': 1}, {'id': 991, 'image_count': 10}, {'id': 992, 'image_count': 9}, {'id': 993, 'image_count': 308}, {'id': 994, 'image_count': 4}, {'id': 995, 'image_count': 1969}, {'id': 
996, 'image_count': 45}, {'id': 997, 'image_count': 12}, {'id': 998, 'image_count': 1}, {'id': 999, 'image_count': 85}, {'id': 1000, 'image_count': 1127}, {'id': 1001, 'image_count': 11}, {'id': 1002, 'image_count': 60}, {'id': 1003, 'image_count': 1}, {'id': 1004, 'image_count': 16}, {'id': 1005, 'image_count': 1}, {'id': 1006, 'image_count': 65}, {'id': 1007, 'image_count': 13}, {'id': 1008, 'image_count': 655}, {'id': 1009, 'image_count': 51}, {'id': 1010, 'image_count': 1}, {'id': 1011, 'image_count': 673}, {'id': 1012, 'image_count': 5}, {'id': 1013, 'image_count': 36}, {'id': 1014, 'image_count': 54}, {'id': 1015, 'image_count': 5}, {'id': 1016, 'image_count': 8}, {'id': 1017, 'image_count': 305}, {'id': 1018, 'image_count': 297}, {'id': 1019, 'image_count': 1053}, {'id': 1020, 'image_count': 223}, {'id': 1021, 'image_count': 1037}, {'id': 1022, 'image_count': 63}, {'id': 1023, 'image_count': 1881}, {'id': 1024, 'image_count': 507}, {'id': 1025, 'image_count': 333}, {'id': 1026, 'image_count': 1911}, {'id': 1027, 'image_count': 1765}, {'id': 1028, 'image_count': 1}, {'id': 1029, 'image_count': 5}, {'id': 1030, 'image_count': 1}, {'id': 1031, 'image_count': 9}, {'id': 1032, 'image_count': 2}, {'id': 1033, 'image_count': 151}, {'id': 1034, 'image_count': 82}, {'id': 1035, 'image_count': 1931}, {'id': 1036, 'image_count': 41}, {'id': 1037, 'image_count': 1895}, {'id': 1038, 'image_count': 24}, {'id': 1039, 'image_count': 22}, {'id': 1040, 'image_count': 35}, {'id': 1041, 'image_count': 69}, {'id': 1042, 'image_count': 962}, {'id': 1043, 'image_count': 588}, {'id': 1044, 'image_count': 21}, {'id': 1045, 'image_count': 825}, {'id': 1046, 'image_count': 52}, {'id': 1047, 'image_count': 5}, {'id': 1048, 'image_count': 5}, {'id': 1049, 'image_count': 5}, {'id': 1050, 'image_count': 1860}, {'id': 1051, 'image_count': 56}, {'id': 1052, 'image_count': 1582}, {'id': 1053, 'image_count': 7}, {'id': 1054, 'image_count': 2}, {'id': 1055, 'image_count': 1562}, {'id': 1056, 
'image_count': 1885}, {'id': 1057, 'image_count': 1}, {'id': 1058, 'image_count': 5}, {'id': 1059, 'image_count': 137}, {'id': 1060, 'image_count': 1094}, {'id': 1061, 'image_count': 134}, {'id': 1062, 'image_count': 29}, {'id': 1063, 'image_count': 22}, {'id': 1064, 'image_count': 522}, {'id': 1065, 'image_count': 50}, {'id': 1066, 'image_count': 68}, {'id': 1067, 'image_count': 16}, {'id': 1068, 'image_count': 40}, {'id': 1069, 'image_count': 35}, {'id': 1070, 'image_count': 135}, {'id': 1071, 'image_count': 1413}, {'id': 1072, 'image_count': 772}, {'id': 1073, 'image_count': 50}, {'id': 1074, 'image_count': 1015}, {'id': 1075, 'image_count': 1}, {'id': 1076, 'image_count': 65}, {'id': 1077, 'image_count': 1900}, {'id': 1078, 'image_count': 1302}, {'id': 1079, 'image_count': 1977}, {'id': 1080, 'image_count': 2}, {'id': 1081, 'image_count': 29}, {'id': 1082, 'image_count': 36}, {'id': 1083, 'image_count': 138}, {'id': 1084, 'image_count': 4}, {'id': 1085, 'image_count': 67}, {'id': 1086, 'image_count': 26}, {'id': 1087, 'image_count': 25}, {'id': 1088, 'image_count': 33}, {'id': 1089, 'image_count': 37}, {'id': 1090, 'image_count': 50}, {'id': 1091, 'image_count': 270}, {'id': 1092, 'image_count': 12}, {'id': 1093, 'image_count': 316}, {'id': 1094, 'image_count': 41}, {'id': 1095, 'image_count': 224}, {'id': 1096, 'image_count': 105}, {'id': 1097, 'image_count': 1925}, {'id': 1098, 'image_count': 1021}, {'id': 1099, 'image_count': 1213}, {'id': 1100, 'image_count': 172}, {'id': 1101, 'image_count': 28}, {'id': 1102, 'image_count': 745}, {'id': 1103, 'image_count': 187}, {'id': 1104, 'image_count': 147}, {'id': 1105, 'image_count': 136}, {'id': 1106, 'image_count': 34}, {'id': 1107, 'image_count': 41}, {'id': 1108, 'image_count': 636}, {'id': 1109, 'image_count': 570}, {'id': 1110, 'image_count': 1149}, {'id': 1111, 'image_count': 61}, {'id': 1112, 'image_count': 1890}, {'id': 1113, 'image_count': 18}, {'id': 1114, 'image_count': 143}, {'id': 1115, 'image_count': 
1517}, {'id': 1116, 'image_count': 7}, {'id': 1117, 'image_count': 943}, {'id': 1118, 'image_count': 6}, {'id': 1119, 'image_count': 1}, {'id': 1120, 'image_count': 11}, {'id': 1121, 'image_count': 101}, {'id': 1122, 'image_count': 1909}, {'id': 1123, 'image_count': 800}, {'id': 1124, 'image_count': 1}, {'id': 1125, 'image_count': 44}, {'id': 1126, 'image_count': 3}, {'id': 1127, 'image_count': 44}, {'id': 1128, 'image_count': 31}, {'id': 1129, 'image_count': 7}, {'id': 1130, 'image_count': 20}, {'id': 1131, 'image_count': 11}, {'id': 1132, 'image_count': 13}, {'id': 1133, 'image_count': 1924}, {'id': 1134, 'image_count': 113}, {'id': 1135, 'image_count': 2}, {'id': 1136, 'image_count': 139}, {'id': 1137, 'image_count': 12}, {'id': 1138, 'image_count': 37}, {'id': 1139, 'image_count': 1866}, {'id': 1140, 'image_count': 47}, {'id': 1141, 'image_count': 1468}, {'id': 1142, 'image_count': 729}, {'id': 1143, 'image_count': 24}, {'id': 1144, 'image_count': 1}, {'id': 1145, 'image_count': 10}, {'id': 1146, 'image_count': 3}, {'id': 1147, 'image_count': 14}, {'id': 1148, 'image_count': 4}, {'id': 1149, 'image_count': 29}, {'id': 1150, 'image_count': 4}, {'id': 1151, 'image_count': 70}, {'id': 1152, 'image_count': 46}, {'id': 1153, 'image_count': 14}, {'id': 1154, 'image_count': 48}, {'id': 1155, 'image_count': 1855}, {'id': 1156, 'image_count': 113}, {'id': 1157, 'image_count': 1}, {'id': 1158, 'image_count': 1}, {'id': 1159, 'image_count': 10}, {'id': 1160, 'image_count': 54}, {'id': 1161, 'image_count': 1923}, {'id': 1162, 'image_count': 630}, {'id': 1163, 'image_count': 31}, {'id': 1164, 'image_count': 69}, {'id': 1165, 'image_count': 7}, {'id': 1166, 'image_count': 11}, {'id': 1167, 'image_count': 1}, {'id': 1168, 'image_count': 30}, {'id': 1169, 'image_count': 50}, {'id': 1170, 'image_count': 45}, {'id': 1171, 'image_count': 28}, {'id': 1172, 'image_count': 114}, {'id': 1173, 'image_count': 193}, {'id': 1174, 'image_count': 21}, {'id': 1175, 'image_count': 91}, 
{'id': 1176, 'image_count': 31}, {'id': 1177, 'image_count': 1469}, {'id': 1178, 'image_count': 1924}, {'id': 1179, 'image_count': 87}, {'id': 1180, 'image_count': 77}, {'id': 1181, 'image_count': 11}, {'id': 1182, 'image_count': 47}, {'id': 1183, 'image_count': 21}, {'id': 1184, 'image_count': 47}, {'id': 1185, 'image_count': 70}, {'id': 1186, 'image_count': 1838}, {'id': 1187, 'image_count': 19}, {'id': 1188, 'image_count': 531}, {'id': 1189, 'image_count': 11}, {'id': 1190, 'image_count': 941}, {'id': 1191, 'image_count': 113}, {'id': 1192, 'image_count': 26}, {'id': 1193, 'image_count': 5}, {'id': 1194, 'image_count': 56}, {'id': 1195, 'image_count': 73}, {'id': 1196, 'image_count': 32}, {'id': 1197, 'image_count': 128}, {'id': 1198, 'image_count': 623}, {'id': 1199, 'image_count': 12}, {'id': 1200, 'image_count': 52}, {'id': 1201, 'image_count': 11}, {'id': 1202, 'image_count': 1674}, {'id': 1203, 'image_count': 81}] # noqa
+# fmt: on
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/pascal_voc.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/pascal_voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..919cc4920394d3cb87ad5232adcbedc250e4db26
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/pascal_voc.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import numpy as np
+import os
+import xml.etree.ElementTree as ET
+from typing import List, Tuple, Union
+
+from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
+from annotator.oneformer.detectron2.structures import BoxMode
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+__all__ = ["load_voc_instances", "register_pascal_voc"]
+
+
+# Default Pascal VOC category names (20 entries). `register_pascal_voc` uses this
+# tuple as its default `class_names`, and `load_voc_instances` turns a name into a
+# category id via its position in the sequence, so the order here is significant.
+# fmt: off
+CLASS_NAMES = (
+ "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
+ "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
+ "pottedplant", "sheep", "sofa", "train", "tvmonitor"
+)
+# fmt: on
+
+
+def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
+ """
+ Load Pascal VOC detection annotations to Detectron2 format.
+
+ Args:
+ dirname: Contain "Annotations", "ImageSets", "JPEGImages"
+ split (str): one of "train", "test", "val", "trainval"
+ class_names: list or tuple of class names
+ """
+ with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
+ fileids = np.loadtxt(f, dtype=np.str)
+
+ # Needs to read many small annotation files. Makes sense at local
+ annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
+ dicts = []
+ for fileid in fileids:
+ anno_file = os.path.join(annotation_dirname, fileid + ".xml")
+ jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
+
+ with PathManager.open(anno_file) as f:
+ tree = ET.parse(f)
+
+ r = {
+ "file_name": jpeg_file,
+ "image_id": fileid,
+ "height": int(tree.findall("./size/height")[0].text),
+ "width": int(tree.findall("./size/width")[0].text),
+ }
+ instances = []
+
+ for obj in tree.findall("object"):
+ cls = obj.find("name").text
+ # We include "difficult" samples in training.
+ # Based on limited experiments, they don't hurt accuracy.
+ # difficult = int(obj.find("difficult").text)
+ # if difficult == 1:
+ # continue
+ bbox = obj.find("bndbox")
+ bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
+ # Original annotations are integers in the range [1, W or H]
+ # Assuming they mean 1-based pixel indices (inclusive),
+ # a box with annotation (xmin=1, xmax=W) covers the whole image.
+ # In coordinate space this is represented by (xmin=0, xmax=W)
+ bbox[0] -= 1.0
+ bbox[1] -= 1.0
+ instances.append(
+ {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
+ )
+ r["annotations"] = instances
+ dicts.append(r)
+ return dicts
+
+
+def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
+ DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
+ MetadataCatalog.get(name).set(
+ thing_classes=list(class_names), dirname=dirname, year=year, split=split
+ )
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/register_coco.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/register_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..e564438d5bf016bcdbb65b4bbdc215d79f579f8a
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/datasets/register_coco.py
@@ -0,0 +1,3 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .coco import register_coco_instances # noqa
+from .coco_panoptic import register_coco_panoptic_separated # noqa
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/detection_utils.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/detection_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b00ca9126d22ecde050d0bb8501871b2cf8f13ff
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/detection_utils.py
@@ -0,0 +1,659 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+Common data processing utilities that are used in a
+typical object detection data pipeline.
+"""
+import logging
+import numpy as np
+from typing import List, Union
+import annotator.oneformer.pycocotools.mask as mask_util
+import torch
+from PIL import Image
+
+from annotator.oneformer.detectron2.structures import (
+ BitMasks,
+ Boxes,
+ BoxMode,
+ Instances,
+ Keypoints,
+ PolygonMasks,
+ RotatedBoxes,
+ polygons_to_bitmask,
+)
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from . import transforms as T
+from .catalog import MetadataCatalog
+
+__all__ = [
+ "SizeMismatchError",
+ "convert_image_to_rgb",
+ "check_image_size",
+ "transform_proposals",
+ "transform_instance_annotations",
+ "annotations_to_instances",
+ "annotations_to_instances_rotated",
+ "build_augmentation",
+ "build_transform_gen",
+ "create_keypoint_hflip_indices",
+ "filter_empty_instances",
+ "read_image",
+]
+
+
+class SizeMismatchError(ValueError):
+ """
+ Raised when a loaded image has a different width/height than its annotation.
+ """
+
+
+# RGB <-> YUV (BT.601) conversion matrices: https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
+_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]
+_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]]
+
+# Tag ids are listed at https://www.exiv2.org/tags.html
+_EXIF_ORIENT = 274 # exif 'Orientation' tag
+
+
+def convert_PIL_to_numpy(image, format):
+ """
+ Convert PIL image to numpy array of target format.
+
+ Args:
+ image (PIL.Image): a PIL image
+ format (str or None): the format of output image; None skips the PIL mode conversion
+
+ Returns:
+ (np.ndarray): also see `read_image`
+ """
+ if format is not None:
+ # PIL only supports RGB, so convert to RGB and finish "BGR" / "YUV-BT.601" below
+ conversion_format = format
+ if format in ["BGR", "YUV-BT.601"]:
+ conversion_format = "RGB"
+ image = image.convert(conversion_format)
+ image = np.asarray(image)
+ # PIL squeezes out the channel dimension for "L", so make it HWC
+ if format == "L":
+ image = np.expand_dims(image, -1)
+
+ # handle formats not supported by PIL
+ elif format == "BGR":
+ # the array is RGB at this point; reverse the channel axis to get BGR
+ image = image[:, :, ::-1]
+ elif format == "YUV-BT.601":
+ image = image / 255.0  # scale the RGB values to 0-1 before the matrix product
+ image = np.dot(image, np.array(_M_RGB2YUV).T)
+
+ return image
+
+
+def convert_image_to_rgb(image, format):
+ """
+ Convert an image from given format to RGB.
+
+ Args:
+ image (np.ndarray or Tensor): an HWC image; a Tensor is moved to CPU and converted to numpy
+ format (str): the format of input image, also see `read_image`
+
+ Returns:
+ (np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8
+ """
+ if isinstance(image, torch.Tensor):
+ image = image.cpu().numpy()
+ if format == "BGR":
+ image = image[:, :, [2, 1, 0]]  # reverse the channel order
+ elif format == "YUV-BT.601":
+ image = np.dot(image, np.array(_M_YUV2RGB).T)
+ image = image * 255.0  # rescale to the 0-255 range
+ else:
+ if format == "L":
+ image = image[:, :, 0]  # drop the channel axis before handing the array to PIL
+ image = image.astype(np.uint8)
+ image = np.asarray(Image.fromarray(image, mode=format).convert("RGB"))
+ return image
+
+
+def _apply_exif_orientation(image):
+ """
+ Applies the exif orientation correctly.
+
+ This code exists per the bug:
+ https://github.com/python-pillow/Pillow/issues/3973
+ with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
+ various methods, especially `tobytes`
+
+ Function based on:
+ https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
+ https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
+
+ Args:
+ image (PIL.Image): a PIL image
+
+ Returns:
+ (PIL.Image): the PIL image with exif orientation applied, if applicable
+ """
+ if not hasattr(image, "getexif"):
+ return image
+
+ try:
+ exif = image.getexif()
+ except Exception: # reading exif can fail: https://github.com/facebookresearch/detectron2/issues/1885
+ exif = None
+
+ if exif is None:
+ return image
+
+ orientation = exif.get(_EXIF_ORIENT)
+
+ method = {  # exif orientation value -> PIL transpose operation
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90,
+ }.get(orientation)  # None when the tag is missing or not listed above (e.g. 1)
+
+ if method is not None:
+ return image.transpose(method)
+ return image
+
+
+def read_image(file_name, format=None):
+ """
+ Read an image into the given format.
+ Will apply rotation and flipping if the image has such exif information.
+
+ Args:
+ file_name (str): image file path
+ format (str or None): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".
+
+ Returns:
+ image (np.ndarray):
+ an HWC image in the given format, which is 0-255, uint8 for
+ supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
+ """
+ with PathManager.open(file_name, "rb") as f:
+ image = Image.open(f)
+
+ # ImageOps.exif_transpose is avoided due to https://github.com/python-pillow/Pillow/issues/3973
+ image = _apply_exif_orientation(image)
+ return convert_PIL_to_numpy(image, format)
+
+
+def check_image_size(dataset_dict, image):
+    """
+    Raise an error if the image does not match the size specified in the dict.
+
+    Args:
+        dataset_dict (dict): may contain "width" and/or "height"; whichever is
+            missing is not checked and is filled in from `image` (in-place).
+        image (np.ndarray): image whose shape starts with (height, width).
+
+    Raises:
+        SizeMismatchError: if a declared width/height differs from the image's.
+    """
+    image_wh = (image.shape[1], image.shape[0])
+    # An undeclared dimension falls back to the actual size (no KeyError when only one key is set).
+    expected_wh = (dataset_dict.get("width", image_wh[0]), dataset_dict.get("height", image_wh[1]))
+    if image_wh != expected_wh:
+        where = (" for image " + dataset_dict["file_name"]) if "file_name" in dataset_dict else ""
+        raise SizeMismatchError(
+            "Mismatched image shape{}, got {}, expect {}.".format(where, image_wh, expected_wh)
+            + " Please check the width/height in your annotation."
+        )
+    # To ensure bbox always remap to original image size
+    dataset_dict.setdefault("width", image.shape[1])
+    dataset_dict.setdefault("height", image.shape[0])
+
+
+def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0):
+ """
+ Apply transformations to the proposals in dataset_dict, if any.
+
+ Args:
+ dataset_dict (dict): a dict read from the dataset, possibly
+ contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
+ image_shape (tuple): height, width
+ transforms (TransformList): its `apply_box` is applied to the proposal boxes
+ proposal_topk (int): only keep top-K scoring proposals
+ min_box_size (int): proposals with either side smaller than this
+ threshold are removed
+
+ The input dict is modified in-place, with the above-mentioned keys removed. A new
+ key "proposals" will be added. Its value is an `Instances`
+ object which contains the transformed proposals in its fields
+ "proposal_boxes" and "objectness_logits".
+ """
+ if "proposal_boxes" in dataset_dict:
+ # Convert the boxes to XYXY_ABS, then transform them
+ boxes = transforms.apply_box(
+ BoxMode.convert(
+ dataset_dict.pop("proposal_boxes"),
+ dataset_dict.pop("proposal_bbox_mode"),
+ BoxMode.XYXY_ABS,
+ )
+ )
+ boxes = Boxes(boxes)
+ objectness_logits = torch.as_tensor(
+ dataset_dict.pop("proposal_objectness_logits").astype("float32")
+ )
+
+ boxes.clip(image_shape)
+ keep = boxes.nonempty(threshold=min_box_size)
+ boxes = boxes[keep]
+ objectness_logits = objectness_logits[keep]
+
+ proposals = Instances(image_shape)
+ proposals.proposal_boxes = boxes[:proposal_topk]  # presumably already sorted by score - TODO confirm
+ proposals.objectness_logits = objectness_logits[:proposal_topk]
+ dataset_dict["proposals"] = proposals
+
+
+def get_bbox(annotation):
+ """
+ Return the instance's "bbox", converted from its "bbox_mode" to XYXY_ABS.
+ Args:
+ annotation (dict): dict of instance annotations for a single instance.
+ Returns:
+ bbox (ndarray): x1, y1, x2, y2 coordinates
+ """
+ # bbox is 1d (per-instance bounding box)
+ bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
+ return bbox
+
+
+def transform_instance_annotations(
+    annotation, transforms, image_size, *, keypoint_hflip_indices=None
+):
+    """
+    Apply transforms to box, segmentation and keypoints annotations of a single instance.
+
+    It will use `transforms.apply_box` for the box, `transforms.apply_polygons` or
+    `transforms.apply_segmentation` for the segmentation, and `transforms.apply_coords` for keypoints.
+    If you need anything more specially designed for each data structure,
+    you'll need to implement your own version of this function or the transforms.
+
+    Args:
+        annotation (dict): dict of instance annotations for a single instance.
+            It will be modified in-place.
+        transforms (TransformList or list[Transform]):
+        image_size (tuple): the height, width of the transformed image
+        keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
+
+    Returns:
+        dict:
+            the same input dict with fields "bbox", "segmentation", "keypoints"
+            transformed according to `transforms`.
+            The "bbox_mode" field will be set to XYXY_ABS.
+    """
+    if isinstance(transforms, (tuple, list)):
+        transforms = T.TransformList(transforms)
+    # bbox is 1d (per-instance bounding box)
+    bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
+    # clip transformed bbox to image size: lower bound 0 here, upper bound (w, h, w, h) below
+    bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
+    annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
+    annotation["bbox_mode"] = BoxMode.XYXY_ABS
+
+    if "segmentation" in annotation:
+        # each instance contains 1 or more polygons
+        segm = annotation["segmentation"]
+        if isinstance(segm, list):
+            # polygons: flat coordinate lists are reshaped to (K, 2) and flattened again
+            polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
+            annotation["segmentation"] = [
+                p.reshape(-1) for p in transforms.apply_polygons(polygons)
+            ]
+        elif isinstance(segm, dict):
+            # RLE: decoded and stored back as a mask array, not re-encoded
+            mask = mask_util.decode(segm)
+            mask = transforms.apply_segmentation(mask)
+            assert tuple(mask.shape[:2]) == image_size
+            annotation["segmentation"] = mask
+        else:
+            raise ValueError(
+                "Cannot transform segmentation of type '{}'! "
+                "Supported types are: polygons as list[list[float] or ndarray],"
+                " COCO-style RLE as a dict.".format(type(segm))
+            )
+
+    if "keypoints" in annotation:
+        keypoints = transform_keypoint_annotations(
+            annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
+        )
+        annotation["keypoints"] = keypoints
+
+    return annotation
+
+
+def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None):
+ """
+ Transform keypoint annotations of an image and return them as an (N, 3) array.
+ If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0)
+
+ Args:
+ keypoints (list[float]): Nx3 float in Detectron2's Dataset format.
+ Each point is represented by (x, y, visibility).
+ transforms (TransformList):
+ image_size (tuple): the height, width of the transformed image
+ keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
+ When `transforms` includes horizontal flip, will use the index
+ mapping to flip keypoints.
+ """
+ # (N*3,) -> (N, 3)
+ keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3)
+ keypoints_xy = transforms.apply_coords(keypoints[:, :2])
+
+ # Set all out-of-boundary points to "unlabeled"
+ inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1]))
+ inside = inside.all(axis=1)
+ keypoints[:, :2] = keypoints_xy
+ keypoints[:, 2][~inside] = 0
+
+ # This assumes that HFlipTransform is the only one that does flip
+ do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1
+
+ # Alternative way: check if probe points were horizontally flipped.
+ # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]])
+ # probe_aug = transforms.apply_coords(probe.copy())
+ # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0]) # noqa
+
+ # If flipped, swap each keypoint with its opposite-handed equivalent
+ if do_hflip:
+ if keypoint_hflip_indices is None:
+ raise ValueError("Cannot flip keypoints without providing flip indices!")
+ if len(keypoints) != len(keypoint_hflip_indices):
+ raise ValueError(
+ "Keypoint data has {} points, but metadata "
+ "contains {} points!".format(len(keypoints), len(keypoint_hflip_indices))
+ )
+ keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :]
+
+ # Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0
+ keypoints[keypoints[:, 2] == 0] = 0
+ return keypoints
+
+
+def annotations_to_instances(annos, image_size, mask_format="polygon"):
+    """
+    Create an :class:`Instances` object used by the models,
+    from instance annotations in the dataset dict.
+
+    Args:
+        annos (list[dict]): a list of instance annotations in one image, each
+            element for one instance.
+        image_size (tuple): height, width
+        mask_format (str): "polygon" or "bitmask"; selects how "gt_masks" is built.
+
+    Returns:
+        Instances:
+            It will contain fields "gt_boxes", "gt_classes",
+            "gt_masks", "gt_keypoints", if they can be obtained from `annos`.
+            This is the format that builtin models expect.
+    """
+    boxes = (
+        np.stack(
+            [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
+        )
+        if len(annos)
+        else np.zeros((0, 4))
+    )
+    target = Instances(image_size)
+    target.gt_boxes = Boxes(boxes)
+
+    classes = [int(obj["category_id"]) for obj in annos]
+    target.gt_classes = torch.tensor(classes, dtype=torch.int64)
+
+    # Only the first annotation decides whether masks are built for the whole image.
+    if len(annos) and "segmentation" in annos[0]:
+        segms = [obj["segmentation"] for obj in annos]
+        if mask_format == "polygon":
+            try:
+                masks = PolygonMasks(segms)
+            except ValueError as e:
+                raise ValueError(
+                    "Failed to use mask_format=='polygon' from the given annotations!"
+                ) from e
+        else:
+            assert mask_format == "bitmask", mask_format
+            masks = []
+            for segm in segms:
+                if isinstance(segm, list):
+                    # polygon
+                    masks.append(polygons_to_bitmask(segm, *image_size))
+                elif isinstance(segm, dict):
+                    # COCO RLE
+                    masks.append(mask_util.decode(segm))
+                elif isinstance(segm, np.ndarray):
+                    assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
+                        segm.ndim
+                    )
+                    # mask array
+                    masks.append(segm)
+                else:
+                    raise ValueError(
+                        "Cannot convert segmentation of type '{}' to BitMasks! "
+                        "Supported types are: polygons as list[list[float] or ndarray],"
+                        " COCO-style RLE as a dict, or a binary segmentation mask"
+                        " in a 2D numpy array of shape HxW.".format(type(segm))
+                    )
+            # torch.from_numpy does not support array with negative stride.
+            masks = BitMasks(
+                torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
+            )
+        target.gt_masks = masks
+
+    if len(annos) and "keypoints" in annos[0]:
+        kpts = [obj.get("keypoints", []) for obj in annos]
+        target.gt_keypoints = Keypoints(kpts)
+
+    return target
+
+
+def annotations_to_instances_rotated(annos, image_size):
+ """
+ Create an :class:`Instances` object used by the models,
+ from instance annotations in the dataset dict.
+ Compared to `annotations_to_instances`, this function is for rotated boxes only.
+
+ Args:
+ annos (list[dict]): a list of instance annotations in one image, each
+ element for one instance.
+ image_size (tuple): height, width
+
+ Returns:
+ Instances:
+ Containing fields "gt_boxes", "gt_classes",
+ if they can be obtained from `annos`.
+ This is the format that builtin models expect.
+ """
+ boxes = [obj["bbox"] for obj in annos]  # used as-is: "bbox_mode" is not read here
+ target = Instances(image_size)
+ boxes = target.gt_boxes = RotatedBoxes(boxes)
+ boxes.clip(image_size)  # same object as target.gt_boxes; presumably clipped in place - TODO confirm
+
+ classes = [obj["category_id"] for obj in annos]
+ classes = torch.tensor(classes, dtype=torch.int64)
+ target.gt_classes = classes
+
+ return target
+
+
+def filter_empty_instances(
+ instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False
+):
+ """
+ Filter out empty instances in an `Instances` object.
+
+ Args:
+ instances (Instances):
+ by_box (bool): whether to filter out instances with empty boxes
+ by_mask (bool): whether to filter out instances with empty masks
+ box_threshold (float): minimum width and height to be considered non-empty
+ return_mask (bool): whether to return boolean mask of filtered instances
+
+ Returns:
+ Instances: the filtered instances.
+ tensor[bool], optional: boolean mask of filtered instances
+ """
+ assert by_box or by_mask
+ r = []
+ if by_box:
+ r.append(instances.gt_boxes.nonempty(threshold=box_threshold))
+ if instances.has("gt_masks") and by_mask:
+ r.append(instances.gt_masks.nonempty())
+
+ # TODO: can also filter visible keypoints
+
+ if not r:  # reachable with by_box=False when there is no "gt_masks" field
+ return instances  # NOTE(review): a bare Instances is returned here even when return_mask=True
+ m = r[0]
+ for x in r[1:]:
+ m = m & x  # keep an instance only if every enabled criterion marks it non-empty
+ if return_mask:
+ return instances[m], m
+ return instances[m]
+
+
+def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]:
+ """
+ Args:
+ dataset_names: a dataset name, or a list of dataset names
+
+ Returns:
+ list[int]: a list of size=#keypoints, storing the
+ horizontally-flipped keypoint indices.
+ """
+ if isinstance(dataset_names, str):
+ dataset_names = [dataset_names]
+
+ check_metadata_consistency("keypoint_names", dataset_names)
+ check_metadata_consistency("keypoint_flip_map", dataset_names)
+
+ meta = MetadataCatalog.get(dataset_names[0])  # the checks above make the first dataset representative
+ names = meta.keypoint_names
+ # TODO flip -> hflip
+ flip_map = dict(meta.keypoint_flip_map)
+ flip_map.update({v: k for k, v in flip_map.items()})  # add the reverse direction of every pair
+ flipped_names = [i if i not in flip_map else flip_map[i] for i in names]  # unmapped names keep their slot
+ flip_indices = [names.index(i) for i in flipped_names]
+ return flip_indices
+
+
+def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0):
+ """
+ Get frequency weight for each class sorted by class id.
+ The frequency weight is each class's image_count raised to the power freq_weight_power.
+
+ Args:
+ dataset_names: a dataset name, or a list of dataset names
+ freq_weight_power: power value
+ """
+ if isinstance(dataset_names, str):
+ dataset_names = [dataset_names]
+
+ check_metadata_consistency("class_image_count", dataset_names)
+
+ meta = MetadataCatalog.get(dataset_names[0])
+ class_freq_meta = meta.class_image_count
+ class_freq = torch.tensor(
+ [c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])]
+ )
+ class_freq_weight = class_freq.float() ** freq_weight_power
+ return class_freq_weight
+
+
+def gen_crop_transform_with_instance(crop_size, image_size, instance):
+ """
+ Generate a CropTransform so that the cropping region contains
+ the center of the given instance.
+
+ Args:
+ crop_size (tuple): h, w in pixels
+ image_size (tuple): h, w
+ instance (dict): an annotation dict of one instance, in Detectron2's
+ dataset format.
+ """
+ crop_size = np.asarray(crop_size, dtype=np.int32)
+ bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS)
+ center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5
+ assert (
+ image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1]
+ ), "The annotation bounding box is outside of the image!"
+ assert (
+ image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1]
+ ), "Crop size is larger than image size!"
+
+ min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0)  # lowest corner still covering the center
+ max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0)  # crop must fit in the image
+ max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32))  # and must not start past the center
+
+ y0 = np.random.randint(min_yx[0], max_yx[0] + 1)  # randint's upper bound is exclusive, hence +1
+ x0 = np.random.randint(min_yx[1], max_yx[1] + 1)
+ return T.CropTransform(x0, y0, crop_size[1], crop_size[0])
+
+
+def check_metadata_consistency(key, dataset_names):
+ """
+ Check that all given datasets have the same value for one metadata key.
+
+ Args:
+ key (str): a metadata key
+ dataset_names (list[str]): a list of dataset names
+
+ Raises:
+ AttributeError: if the key does not exist in the metadata
+ ValueError: if the given datasets do not have the same metadata values defined by key
+ """
+ if len(dataset_names) == 0:
+ return  # nothing to compare
+ logger = logging.getLogger(__name__)
+ entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names]
+ for idx, entry in enumerate(entries_per_dataset):
+ if entry != entries_per_dataset[0]:  # every dataset is compared against the first one
+ logger.error(
+ "Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry))
+ )
+ logger.error(
+ "Metadata '{}' for dataset '{}' is '{}'".format(
+ key, dataset_names[0], str(entries_per_dataset[0])
+ )
+ )
+ raise ValueError("Datasets have different metadata '{}'!".format(key))
+
+
+def build_augmentation(cfg, is_train):
+ """
+ Create a list of default :class:`Augmentation` from config.
+ It includes shortest-edge resizing and, when training, an optional random flip.
+
+ Returns:
+ list[Augmentation]
+ """
+ if is_train:
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
+ else:
+ min_size = cfg.INPUT.MIN_SIZE_TEST
+ max_size = cfg.INPUT.MAX_SIZE_TEST
+ sample_style = "choice"  # fixed at test time; only training reads the sampling style from cfg
+ augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
+ if is_train and cfg.INPUT.RANDOM_FLIP != "none":
+ augmentation.append(
+ T.RandomFlip(
+ horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
+ vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
+ )
+ )
+ return augmentation
+
+
+build_transform_gen = build_augmentation
+"""
+Alias of `build_augmentation`, kept for backward-compatibility.
+"""
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85c9f1a9df8a4038fbd4246239b699402e382309
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .distributed_sampler import (
+ InferenceSampler,
+ RandomSubsetTrainingSampler,
+ RepeatFactorTrainingSampler,
+ TrainingSampler,
+)
+
+from .grouped_batch_sampler import GroupedBatchSampler
+
+__all__ = [
+ "GroupedBatchSampler",
+ "TrainingSampler",
+ "RandomSubsetTrainingSampler",
+ "InferenceSampler",
+ "RepeatFactorTrainingSampler",
+]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/distributed_sampler.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd4724eac8fbff2456bd26f95e6fea5e914b73e2
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/distributed_sampler.py
@@ -0,0 +1,278 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import itertools
+import logging
+import math
+from collections import defaultdict
+from typing import Optional
+import torch
+from torch.utils.data.sampler import Sampler
+
+from annotator.oneformer.detectron2.utils import comm
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingSampler(Sampler):
+ """
+ In training, we only care about the "infinite stream" of training data.
+ So this sampler produces an infinite stream of indices and
+ all workers cooperate to correctly shuffle the indices and sample different indices.
+
+ The samplers in each worker effectively produce `indices[worker_id::num_workers]`
+ where `indices` is an infinite stream of indices consisting of
+ `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
+ or `range(size) + range(size) + ...` (if shuffle is False)
+
+ Note that this sampler does not shard based on pytorch DataLoader worker id.
+ A sampler passed to pytorch DataLoader is used only with map-style dataset
+ and will not be executed inside workers.
+ But if this sampler is used in a way that it gets executed inside a dataloader
+ worker, then extra work needs to be done to shard its outputs based on worker id.
+ This is required so that workers don't produce identical data.
+ :class:`ToIterableDataset` implements this logic.
+ This note is true for all samplers in detectron2.
+ """
+
+ def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
+ """
+ Args:
+ size (int): the total number of data of the underlying dataset to sample from
+ shuffle (bool): whether to shuffle the indices or not
+ seed (int): the initial seed of the shuffle. Must be the same
+ across all workers. If None, will use a random seed shared
+ among workers (require synchronization among all workers).
+ """
+ if not isinstance(size, int):
+ raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.")
+ if size <= 0:
+ raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.")
+ self._size = size
+ self._shuffle = shuffle
+ if seed is None:
+ seed = comm.shared_random_seed()
+ self._seed = int(seed)
+
+ self._rank = comm.get_rank()
+ self._world_size = comm.get_world_size()
+
+ def __iter__(self):
+ start = self._rank
+ # every rank walks the same stream and keeps each world_size-th index, starting at its rank
+ yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self._seed)
+ while True:
+ if self._shuffle:
+ yield from torch.randperm(self._size, generator=g).tolist()
+ else:
+ yield from torch.arange(self._size).tolist()
+
+
+class RandomSubsetTrainingSampler(TrainingSampler):
+ """
+ Similar to TrainingSampler, but only sample a random subset of indices.
+ This is useful when you want to estimate the accuracy vs data-number curves by
+ training the model with different subset_ratio.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ subset_ratio: float,
+ shuffle: bool = True,
+ seed_shuffle: Optional[int] = None,
+ seed_subset: Optional[int] = None,
+ ):
+ """
+ Args:
+ size (int): the total number of data of the underlying dataset to sample from
+ subset_ratio (float): the ratio of subset data to sample from the underlying dataset
+ shuffle (bool): whether to shuffle the indices or not
+ seed_shuffle (int): the initial seed of the shuffle. Must be the same
+ across all workers. If None, will use a random seed shared
+ among workers (require synchronization among all workers).
+ seed_subset (int): the seed to randomize the subset to be sampled.
+ Must be the same across all workers. If None, will use a random seed shared
+ among workers (require synchronization among all workers).
+ """
+ super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle)
+
+ assert 0.0 < subset_ratio <= 1.0
+ self._size_subset = int(size * subset_ratio)  # int() truncates, so a tiny ratio can give 0
+ assert self._size_subset > 0
+ if seed_subset is None:
+ seed_subset = comm.shared_random_seed()
+ self._seed_subset = int(seed_subset)
+
+ # the subset is drawn once here and stays fixed for the lifetime of the sampler
+ g = torch.Generator()
+ g.manual_seed(self._seed_subset)
+ indexes_randperm = torch.randperm(self._size, generator=g)
+ self._indexes_subset = indexes_randperm[: self._size_subset]
+
+ logger.info("Using RandomSubsetTrainingSampler......")
+ logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data")
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self._seed) # self._seed equals seed_shuffle from __init__()
+ while True:
+ if self._shuffle:
+ # generate a random permutation to shuffle self._indexes_subset
+ randperm = torch.randperm(self._size_subset, generator=g)
+ yield from self._indexes_subset[randperm].tolist()
+ else:
+ yield from self._indexes_subset.tolist()
+
+
+class RepeatFactorTrainingSampler(Sampler):
+ """
+ Similar to TrainingSampler, but a sample may appear more times than others based
+ on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS.
+ """
+
+ def __init__(self, repeat_factors, *, shuffle=True, seed=None):
+ """
+ Args:
+ repeat_factors (Tensor): a float vector, the repeat factor for each index. When it's
+ full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``.
+ shuffle (bool): whether to shuffle the indices or not
+ seed (int): the initial seed of the shuffle. Must be the same
+ across all workers. If None, will use a random seed shared
+ among workers (require synchronization among all workers).
+ """
+ self._shuffle = shuffle
+ if seed is None:
+ seed = comm.shared_random_seed()
+ self._seed = int(seed)
+
+ self._rank = comm.get_rank()
+ self._world_size = comm.get_world_size()
+
+ # Split into whole number (_int_part) and fractional (_frac_part) parts.
+ self._int_part = torch.trunc(repeat_factors)
+ self._frac_part = repeat_factors - self._int_part
+
+ @staticmethod
+ def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh):
+ """
+ Compute (fractional) per-image repeat factors based on category frequency.
+ The repeat factor for an image is a function of the frequency of the rarest
+ category labeled in that image. The "frequency of category c" in [0, 1] is defined
+ as the fraction of images in the training set (without repeats) in which category c
+ appears.
+ See :paper:`lvis` (>= v2) Appendix B.2.
+
+ Args:
+ dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
+ repeat_thresh (float): frequency threshold below which data is repeated.
+ If the frequency is half of `repeat_thresh`, the image will be
+ repeated twice.
+
+ Returns:
+ torch.Tensor:
+ the i-th element is the repeat factor for the dataset image at index i.
+ """
+ # 1. For each category c, compute the fraction of images that contain it: f(c)
+ category_freq = defaultdict(int)
+ for dataset_dict in dataset_dicts: # For each image (without repeats)
+ cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
+ for cat_id in cat_ids:
+ category_freq[cat_id] += 1
+ num_images = len(dataset_dicts)
+ for k, v in category_freq.items():
+ category_freq[k] = v / num_images
+
+ # 2. For each category c, compute the category-level repeat factor:
+ # r(c) = max(1, sqrt(t / f(c)))
+ category_rep = {
+ cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
+ for cat_id, cat_freq in category_freq.items()
+ }
+
+ # 3. For each image I, compute the image-level repeat factor:
+ # r(I) = max_{c in I} r(c)
+ rep_factors = []
+ for dataset_dict in dataset_dicts:
+ cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]}
+ rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)  # 1.0 if no annotations
+ rep_factors.append(rep_factor)
+
+ return torch.tensor(rep_factors, dtype=torch.float32)
+
+ def _get_epoch_indices(self, generator):
+ """
+ Create a list of dataset indices (with repeats) to use for one epoch.
+
+ Args:
+ generator (torch.Generator): pseudo random number generator used for
+ stochastic rounding.
+
+ Returns:
+ torch.Tensor: list of dataset indices to use in one epoch. Each index
+ is repeated based on its calculated repeat factor.
+ """
+ # Since repeat factors are fractional, we use stochastic rounding so
+ # that the target repeat factor is achieved in expectation over the
+ # course of training
+ rands = torch.rand(len(self._frac_part), generator=generator)
+ rep_factors = self._int_part + (rands < self._frac_part).float()
+ # Construct a list of indices in which we repeat images as specified
+ indices = []
+ for dataset_index, rep_factor in enumerate(rep_factors):
+ indices.extend([dataset_index] * int(rep_factor.item()))  # whole-valued: int part plus 0 or 1
+ return torch.tensor(indices, dtype=torch.int64)
+
+ def __iter__(self):
+ start = self._rank
+ yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self._seed)
+ while True:
+ # Sample indices with repeats determined by stochastic rounding; each
+ # "epoch" may have a slightly different size due to the rounding.
+ indices = self._get_epoch_indices(g)
+ if self._shuffle:
+ randperm = torch.randperm(len(indices), generator=g)
+ yield from indices[randperm].tolist()
+ else:
+ yield from indices.tolist()
+
+
+class InferenceSampler(Sampler):
+ """
+ Produce indices for inference across all workers.
+ Inference needs to run on the __exact__ set of samples,
+ therefore when the total number of samples is not divisible by the number of workers,
+ this sampler produces a different number of samples on different workers.
+ """
+
+ def __init__(self, size: int):
+ """
+ Args:
+ size (int): the total number of data of the underlying dataset to sample from
+ """
+ self._size = size
+ assert size > 0
+ self._rank = comm.get_rank()
+ self._world_size = comm.get_world_size()
+ self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
+
+ @staticmethod
+ def _get_local_indices(total_size, world_size, rank):
+ shard_size = total_size // world_size
+ left = total_size % world_size
+ shard_sizes = [shard_size + int(r < left) for r in range(world_size)]  # first `left` ranks get one extra
+
+ begin = sum(shard_sizes[:rank])
+ end = min(sum(shard_sizes[: rank + 1]), total_size)
+ return range(begin, end)  # a contiguous block of indices for this rank
+
+ def __iter__(self):
+ yield from self._local_indices
+
+ def __len__(self):
+ return len(self._local_indices)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/grouped_batch_sampler.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/grouped_batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b247730aacd04dd0c752664acde3257c4eddd71
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/samplers/grouped_batch_sampler.py
@@ -0,0 +1,47 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import numpy as np
+from torch.utils.data.sampler import BatchSampler, Sampler
+
+
+class GroupedBatchSampler(BatchSampler):
+    """
+    Batch sampler that draws indices from a wrapped sampler and emits
+    mini-batches whose elements all belong to the same group.
+    Batches are emitted in the order in which they fill up, which keeps the
+    result as close as possible to the ordering of the wrapped sampler.
+    """
+
+    def __init__(self, sampler, group_ids, batch_size):
+        """
+        Args:
+            sampler (Sampler): Base sampler.
+            group_ids (list[int]): If the sampler produces indices in range [0, N),
+                `group_ids` must be a list of `N` ints which contains the group id of each sample.
+                The group ids must be a set of integers in the range [0, num_groups).
+            batch_size (int): Size of mini-batch.
+        """
+        if not isinstance(sampler, Sampler):
+            message = (
+                "sampler should be an instance of "
+                "torch.utils.data.Sampler, but got sampler={}"
+            )
+            raise ValueError(message.format(sampler))
+        self.sampler = sampler
+        self.group_ids = np.asarray(group_ids)
+        assert self.group_ids.ndim == 1
+        self.batch_size = batch_size
+
+        # One pending list per group; indices wait there until a batch is full.
+        self.buffer_per_group = {gid: [] for gid in np.unique(self.group_ids).tolist()}
+
+    def __iter__(self):
+        for idx in self.sampler:
+            pending = self.buffer_per_group[self.group_ids[idx]]
+            pending.append(idx)
+            if len(pending) != self.batch_size:
+                continue
+            yield list(pending)  # hand out a copy, then reuse the buffer
+            pending.clear()
+
+    def __len__(self):
+        raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.")
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91c6cdfacd6992a7a1e80c7d2e4b38b2cf7dcde
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from fvcore.transforms.transform import Transform, TransformList # order them first
+from fvcore.transforms.transform import *
+from .transform import *
+from .augmentation import *
+from .augmentation_impl import *
+# Export every public (non-underscore) name that is in this namespace at this point.
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
+
+# Imported only after `__all__` has been built, so the helper is not listed in it.
+from annotator.oneformer.detectron2.utils.env import fixup_module_metadata
+
+fixup_module_metadata(__name__, globals(), __all__)
+del fixup_module_metadata
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..63dd41aef658c9b51c7246880399405a029c5580
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation.py
@@ -0,0 +1,380 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import inspect
+import numpy as np
+import pprint
+from typing import Any, List, Optional, Tuple, Union
+from fvcore.transforms.transform import Transform, TransformList
+
+"""
+See "Data Augmentation" tutorial for an overview of the system:
+https://detectron2.readthedocs.io/tutorials/augmentation.html
+"""
+
+
+__all__ = [
+ "Augmentation",
+ "AugmentationList",
+ "AugInput",
+ "TransformGen",
+ "apply_transform_gens",
+ "StandardAugInput",
+ "apply_augmentations",
+]
+
+
+def _check_img_dtype(img):
+    """Assert that `img` is a 2D/3D ndarray with dtype uint8 or a non-integer dtype."""
+    assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format(
+        type(img)
+    )
+    # A dtype object is never an instance of `np.integer`, so test it with `np.issubdtype`.
+    assert not np.issubdtype(img.dtype, np.integer) or img.dtype == np.uint8, (
+        "[Augmentation] Got image of type {}, use uint8 or floating points instead!"
+    ).format(img.dtype)
+    assert img.ndim in [2, 3], img.ndim
+
+
+def _get_aug_input_args(aug, aug_input) -> List[Any]:
+    """
+    Collect from ``aug_input`` the positional arguments for ``aug.get_transform``.
+
+    When ``aug.input_args`` is None, the attribute names are inferred from the
+    signature of ``get_transform`` and stored in ``aug.input_args``.
+    """
+    if aug.input_args is None:
+        # Infer the attribute names from the signature of `get_transform`.
+        params = inspect.signature(aug.get_transform).parameters
+        if len(params) == 1:
+            # A lone parameter always means "image", whatever it is called
+            # (covers the majority of use cases and avoids BC breaking).
+            names = ("image",)
+        else:
+            # Otherwise each parameter name is the attribute name to read.
+            variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
+            if any(p.kind in variadic for p in params.values()):
+                raise TypeError(
+                    f" The default implementation of `{type(aug)}.__call__` does not allow "
+                    f"`{type(aug)}.get_transform` to use variable-length arguments "
+                    "(*args, **kwargs)! "
+                    "If arguments are unknown, reimplement `__call__` instead. "
+                )
+            names = list(params)
+        aug.input_args = tuple(names)
+
+    # Read every required attribute off the input, failing with a clear message.
+    args = []
+    for attr in aug.input_args:
+        try:
+            value = getattr(aug_input, attr)
+        except AttributeError as e:
+            raise AttributeError(
+                f"{type(aug)}.get_transform needs input attribute '{attr}', "
+                f"but it is not an attribute of {type(aug_input)}!"
+            ) from e
+        args.append(value)
+    return args
+
+
+class Augmentation:
+    """
+    An Augmentation is a (usually randomized) policy that produces a :class:`Transform`
+    from data. It is typically used to pre-process input data.
+
+    In the most general case, the "policy" behind a :class:`Transform` may need
+    arbitrary information from the input data to decide which transform to apply.
+    Each :class:`Augmentation` instance therefore declares the arguments that its
+    :meth:`get_transform` method needs. Calling :meth:`get_transform` with those
+    positional arguments executes the policy.
+
+    An :class:`Augmentation` only decides which :class:`Transform` to create; it does
+    not define how the transform operations are carried out on the data.
+    Its :meth:`__call__` method delegates that to :meth:`AugInput.transform`.
+
+    The returned `Transform` object is meant to describe a deterministic transformation, so
+    it can be re-applied to associated data, e.g. the geometry of an image and its segmentation
+    masks need to be transformed together.
+    (If such re-application is not needed, then determinism is not a crucial requirement.)
+    """
+
+    input_args: Optional[Tuple[str]] = None
+    """
+    Names of the attributes that :meth:`get_transform` needs, e.g. ``("image", "sem_seg")``.
+    When left as None, the names are taken from the parameters of :meth:`self.get_transform`,
+    which often is just "image". Users who follow the argument name convention never need
+    to touch this attribute.
+    """
+
+    def _init(self, params=None):
+        # Store every public entry of `params` (typically `locals()`) as an attribute.
+        for key, value in (params or {}).items():
+            if not (key == "self" or key.startswith("_")):
+                setattr(self, key, value)
+
+    def get_transform(self, *args) -> Transform:
+        """
+        Run the policy on the input data and decide which transform should be applied.
+
+        Args:
+            args: Any fixed-length positional arguments. By default, the name of the arguments
+                should exist in the :class:`AugInput` to be used.
+
+        Returns:
+            Transform: Returns the deterministic transform to apply to the input.
+
+        Examples:
+        ::
+            class MyAug(Augmentation):
+                # if a policy needs to know both image and semantic segmentation
+                def get_transform(self, image, sem_seg) -> T.Transform:
+                    pass
+            tfm: Transform = MyAug().get_transform(image, sem_seg)
+            new_image = tfm.apply_image(image)
+
+        Notes:
+            Custom :meth:`get_transform` methods may use arbitrary new argument
+            names, provided that attributes of those names are available in the
+            input data. In detectron2 we use the following convention:
+
+            * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
+              floating point in range [0, 1] or [0, 255].
+            * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes
+              of N instances. Each is in XYXY format in unit of absolute coordinates.
+            * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel.
+
+            We do not specify convention for other types and do not include builtin
+            :class:`Augmentation` that uses other types in detectron2.
+        """
+        raise NotImplementedError
+
+    def __call__(self, aug_input) -> Transform:
+        """
+        Augment the given `aug_input` **in-place**, and return the transform that's used.
+
+        This is the entry point for applying the augmentation. The default
+        implementation, which feeds the inputs to :meth:`get_transform`, is enough for
+        most augmentations; subclasses with more complicated logic may override it.
+
+        Args:
+            aug_input (AugInput): an object that has attributes needed by this augmentation
+                (defined by ``self.get_transform``). Its ``transform`` method will be called
+                to in-place transform it.
+
+        Returns:
+            Transform: the transform that is applied on the input.
+        """
+        call_args = _get_aug_input_args(self, aug_input)
+        tfm = self.get_transform(*call_args)
+        assert isinstance(tfm, (Transform, TransformList)), (
+            f"{type(self)}.get_transform must return an instance of Transform! "
+            f"Got {type(tfm)} instead."
+        )
+        aug_input.transform(tfm)
+        return tfm
+
+    def _rand_range(self, low=1.0, high=None, size=None):
+        """
+        Uniform float random number between low and high.
+        With `high` omitted, the single bound given is the upper one and the range
+        starts at 0. With `size` omitted, an empty shape (``[]``) is passed to numpy.
+        """
+        lower, upper = (0, low) if high is None else (low, high)
+        shape = [] if size is None else size
+        return np.random.uniform(lower, upper, shape)
+
+    def __repr__(self):
+        """
+        Produce something like:
+        "MyAugmentation(field1={self.field1}, field2={self.field2})"
+        Only constructor arguments whose attribute is not the default object are listed.
+        """
+        try:
+            classname = type(self).__name__
+            parts = []
+            for name, param in inspect.signature(self.__init__).parameters.items():
+                assert param.kind not in (
+                    param.VAR_POSITIONAL,
+                    param.VAR_KEYWORD,
+                ), "The default __repr__ doesn't support *args or **kwargs"
+                assert hasattr(self, name), (
+                    "Attribute {} not found! "
+                    "Default __repr__ only works if attributes match the constructor.".format(name)
+                )
+                value = getattr(self, name)
+                if value is param.default:
+                    # arguments left at their default are not shown
+                    continue
+                text = pprint.pformat(value)
+                # don't show the value if pformat decides to use >1 lines
+                parts.append("{}={}".format(name, "..." if "\n" in text else text))
+            return "{}({})".format(classname, ", ".join(parts))
+        except AssertionError:
+            # fall back to the plain object repr when the convention is not met
+            return super().__repr__()
+
+    __str__ = __repr__
+
+
+class _TransformToAug(Augmentation):
+    """Adapter that presents one fixed :class:`Transform` as an :class:`Augmentation`."""
+
+    def __init__(self, tfm: Transform):
+        self.tfm = tfm
+
+    def get_transform(self, *args):
+        # the wrapped transform is returned no matter what the inputs are
+        return self.tfm
+
+    def __repr__(self):
+        return repr(self.tfm)
+
+    __str__ = __repr__
+
+
+def _transform_to_aug(tfm_or_aug):
+    """
+    Wrap Transform into Augmentation.
+    Private, used internally to implement augmentations.
+    """
+    assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug
+    return tfm_or_aug if isinstance(tfm_or_aug, Augmentation) else _TransformToAug(tfm_or_aug)
+
+
+class AugmentationList(Augmentation):
+    """
+    Apply a sequence of augmentations.
+
+    Use its ``__call__`` method to apply the augmentations.
+
+    :meth:`get_transform` cannot be implemented for :class:`AugmentationList`
+    (calling it raises an error): the kth augmentation has to be applied to the
+    input before the (k+1)th augmentation can be given the inputs it needs, so
+    the transforms cannot all be decided up front.
+    """
+
+    def __init__(self, augs):
+        """
+        Args:
+            augs (list[Augmentation or Transform]):
+        """
+        super().__init__()
+        self.augs = list(map(_transform_to_aug, augs))
+
+    def __call__(self, aug_input) -> TransformList:
+        # Apply the augmentations one after another: each call transforms
+        # `aug_input` in place, so later augmentations see the updated data.
+        # The transforms are collected in application order.
+        applied = [aug(aug_input) for aug in self.augs]
+        return TransformList(applied)
+
+    def __repr__(self):
+        msgs = (str(aug) for aug in self.augs)
+        return "AugmentationList[{}]".format(", ".join(msgs))
+
+    __str__ = __repr__
+
+
+class AugInput:
+    """
+    Input that can be used with :meth:`Augmentation.__call__`.
+    This is a standard implementation for the majority of use cases.
+    It carries the standard attributes **"image", "boxes", "sem_seg"**, all set
+    in :meth:`__init__`, which different augmentations may need.
+    Most augmentation policies do not need attributes beyond these three.
+
+    After these attributes have been augmented (through :meth:`AugInput.transform`),
+    the returned transforms can be used to transform other data structures that users have.
+
+    Examples:
+    ::
+        input = AugInput(image, boxes=boxes)
+        tfms = augmentation(input)
+        transformed_image = input.image
+        transformed_boxes = input.boxes
+        transformed_other_data = tfms.apply_other(other_data)
+
+    An extended project that works with new data types may implement augmentation policies
+    that need other inputs. An algorithm may need to transform inputs in a way different
+    from the standard approach defined in this class. In those rare situations, users can
+    implement a class similar to this class, that satisfies the following conditions:
+
+    * The input must provide access to these data in the form of attribute access
+      (``getattr``). For example, if an :class:`Augmentation` to be applied needs "image"
+      and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg".
+    * The input must have a ``transform(tfm: Transform) -> None`` method which
+      in-place transforms all its attributes.
+    """
+
+    # TODO maybe should support more builtin data types here
+    def __init__(
+        self,
+        image: np.ndarray,
+        *,
+        boxes: Optional[np.ndarray] = None,
+        sem_seg: Optional[np.ndarray] = None,
+    ):
+        """
+        Args:
+            image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or
+                floating point in range [0, 1] or [0, 255]. The meaning of C is up
+                to users.
+            boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode
+            sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element
+                is an integer label of pixel.
+        """
+        # validate the image first so that an unsupported array is never stored
+        _check_img_dtype(image)
+        self.image, self.boxes = image, boxes
+        self.sem_seg = sem_seg
+
+    def transform(self, tfm: Transform) -> None:
+        """
+        In-place transform all attributes of this class.
+
+        "In-place" means that after this method returns, reading an attribute
+        such as ``self.image`` gives the transformed data.
+        """
+        self.image = tfm.apply_image(self.image)
+        # `boxes` and `sem_seg` are optional and stay None when they were not given
+        boxes, sem_seg = self.boxes, self.sem_seg
+        self.boxes = boxes if boxes is None else tfm.apply_box(boxes)
+        self.sem_seg = sem_seg if sem_seg is None else tfm.apply_segmentation(sem_seg)
+
+    def apply_augmentations(
+        self, augmentations: List[Union[Augmentation, Transform]]
+    ) -> TransformList:
+        """Shortcut for ``AugmentationList(augmentations)(self)``."""
+        # wrap the list once, then run it on this very object
+        pipeline = AugmentationList(augmentations)
+        return pipeline(self)
+
+
+def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs):
+    """
+    Use ``T.AugmentationList(augmentations)(inputs)`` instead.
+    """
+    # A bare ndarray is the common image-only case (also kept for backward
+    # compatibility): wrap it for processing and unwrap the image on return.
+    image_only = isinstance(inputs, np.ndarray)
+    if image_only:
+        inputs = AugInput(inputs)
+    tfms = inputs.apply_augmentations(augmentations)
+    result = inputs.image if image_only else inputs
+    return result, tfms
+
+
+apply_transform_gens = apply_augmentations
+"""
+Alias for backward-compatibility.
+"""
+# `TransformGen` is bound to the very same class object as `Augmentation`.
+TransformGen = Augmentation
+"""
+Alias for Augmentation, since it is something that generates :class:`Transform`s
+"""
+# `StandardAugInput` is bound to the very same class object as `AugInput`.
+StandardAugInput = AugInput
+"""
+Alias for compatibility. It's not worth the complexity to have two classes.
+"""
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation_impl.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..965f0a947d7c3ff03b0990f1a645703d470227de
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/augmentation_impl.py
@@ -0,0 +1,736 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+Implement many useful :class:`Augmentation`.
+"""
+import numpy as np
+import sys
+from numpy import random
+from typing import Tuple
+import torch
+from fvcore.transforms.transform import (
+ BlendTransform,
+ CropTransform,
+ HFlipTransform,
+ NoOpTransform,
+ PadTransform,
+ Transform,
+ TransformList,
+ VFlipTransform,
+)
+from PIL import Image
+
+from annotator.oneformer.detectron2.structures import Boxes, pairwise_iou
+
+from .augmentation import Augmentation, _transform_to_aug
+from .transform import ExtentTransform, ResizeTransform, RotationTransform
+
+__all__ = [
+ "FixedSizeCrop",
+ "RandomApply",
+ "RandomBrightness",
+ "RandomContrast",
+ "RandomCrop",
+ "RandomExtent",
+ "RandomFlip",
+ "RandomSaturation",
+ "RandomLighting",
+ "RandomRotation",
+ "Resize",
+ "ResizeScale",
+ "ResizeShortestEdge",
+ "RandomCrop_CategoryAreaConstraint",
+ "RandomResize",
+ "MinIoURandomCrop",
+]
+
+
+class RandomApply(Augmentation):
+    """
+    Apply the wrapped augmentation with probability `prob`; otherwise apply nothing.
+    """
+
+    def __init__(self, tfm_or_aug, prob=0.5):
+        """
+        Args:
+            tfm_or_aug (Transform, Augmentation): the transform or augmentation
+                to be applied. It can either be a `Transform` or `Augmentation`
+                instance.
+            prob (float): probability between 0.0 and 1.0 that
+                the wrapper transformation is applied
+        """
+        super().__init__()
+        self.aug = _transform_to_aug(tfm_or_aug)
+        assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
+        self.prob = prob
+
+    def get_transform(self, *args):
+        # one uniform draw from the default range decides whether the policy runs
+        if self._rand_range() < self.prob:
+            return self.aug.get_transform(*args)
+        # skipped: report a NoOpTransform instead
+        return NoOpTransform()
+
+    def __call__(self, aug_input):
+        # same decision as in `get_transform`, but the wrapped augmentation is
+        # called on `aug_input` itself when it is selected
+        if self._rand_range() < self.prob:
+            return self.aug(aug_input)
+        return NoOpTransform()
+
+
+class RandomFlip(Augmentation):
+    """
+    Flip the image horizontally or vertically with the given probability.
+    """
+
+    def __init__(self, prob=0.5, *, horizontal=True, vertical=False):
+        """
+        Args:
+            prob (float): probability of flip.
+            horizontal (boolean): whether to apply horizontal flipping
+            vertical (boolean): whether to apply vertical flipping
+        """
+        super().__init__()
+        # exactly one of the two directions has to be selected
+        if horizontal and vertical:
+            raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.")
+        if not (horizontal or vertical):
+            raise ValueError("At least one of horiz or vert has to be True!")
+        self._init(locals())
+
+    def get_transform(self, image):
+        h, w = image.shape[:2]
+        # one uniform draw from the default range is compared against `prob`
+        if not self._rand_range() < self.prob:
+            return NoOpTransform()
+        # `__init__` rejects every combination other than exactly one direction
+        if self.horizontal:
+            return HFlipTransform(w)
+        if self.vertical:
+            return VFlipTransform(h)
+
+
+class Resize(Augmentation):
+    """Resize image to a fixed target size"""
+
+    def __init__(self, shape, interp=Image.BILINEAR):
+        """
+        Args:
+            shape: (h, w) tuple or a int
+            interp: PIL interpolation method
+        """
+        # a single int is shorthand for a square (shape, shape) target
+        shape = (shape, shape) if isinstance(shape, int) else tuple(shape)
+        # `shape` and `interp` become attributes of the same name
+        self._init(locals())
+
+    def get_transform(self, image):
+        old_h, old_w = image.shape[0], image.shape[1]
+        new_h, new_w = self.shape[0], self.shape[1]
+        return ResizeTransform(old_h, old_w, new_h, new_w, self.interp)
+
+
+class ResizeShortestEdge(Augmentation):
+    """
+    Resize the image while keeping the aspect ratio unchanged.
+    The shorter edge is scaled to a sampled target length, provided that
+    the longer edge then stays within `max_size`.
+    Otherwise the image is scaled down so that the longer edge does not exceed `max_size`.
+    """
+
+    @torch.jit.unused
+    def __init__(
+        self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR
+    ):
+        """
+        Args:
+            short_edge_length (list[int]): If ``sample_style=="range"``,
+                a [min, max] interval from which to sample the shortest edge length.
+                If ``sample_style=="choice"``, a list of shortest edge lengths to sample from.
+            max_size (int): maximum allowed longest edge length.
+            sample_style (str): either "range" or "choice".
+        """
+        super().__init__()
+        assert sample_style in ["range", "choice"], sample_style
+        self.is_range = sample_style == "range"
+        # a single int pins the short edge to exactly that length
+        if isinstance(short_edge_length, int):
+            short_edge_length = (short_edge_length, short_edge_length)
+        if self.is_range:
+            assert len(short_edge_length) == 2, (
+                "short_edge_length must be two values using 'range' sample style."
+                f" Got {short_edge_length}!"
+            )
+        self._init(locals())
+
+    @torch.jit.unused
+    def get_transform(self, image):
+        h, w = image.shape[:2]
+        if self.is_range:
+            low, high = self.short_edge_length[0], self.short_edge_length[1]
+            size = np.random.randint(low, high + 1)  # upper bound is inclusive
+        else:
+            size = np.random.choice(self.short_edge_length)
+        if size == 0:
+            return NoOpTransform()
+        newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size)
+        return ResizeTransform(h, w, newh, neww, self.interp)
+
+    @staticmethod
+    def get_output_shape(
+        oldh: int, oldw: int, short_edge_length: int, max_size: int
+    ) -> Tuple[int, int]:
+        """
+        Compute the output size given input size and target short edge length.
+        """
+        target = short_edge_length * 1.0
+        ratio = target / min(oldh, oldw)
+        # the shorter side lands exactly on the target, the other one follows the ratio
+        if oldh < oldw:
+            newh, neww = target, ratio * oldw
+        else:
+            newh, neww = ratio * oldh, target
+        longest = max(newh, neww)
+        if longest > max_size:
+            # shrink both sides so that the longer one fits into max_size
+            shrink = max_size * 1.0 / longest
+            newh, neww = newh * shrink, neww * shrink
+        # add 0.5 and truncate to get whole pixels
+        return (int(newh + 0.5), int(neww + 0.5))
+
+
+class ResizeScale(Augmentation):
+    """
+    Takes a target size and randomly scales it by a factor between `min_scale`
+    and `max_scale`. The input image is then resized so that it fits inside the scaled
+    target box, keeping the aspect ratio constant.
+    This implements the resize part of the Google's 'resize_and_crop' data augmentation:
+    https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127
+    """
+
+    def __init__(
+        self,
+        min_scale: float,
+        max_scale: float,
+        target_height: int,
+        target_width: int,
+        interp: int = Image.BILINEAR,
+    ):
+        """
+        Args:
+            min_scale: minimum image scale range.
+            max_scale: maximum image scale range.
+            target_height: target image height.
+            target_width: target image width.
+            interp: image interpolation method.
+        """
+        super().__init__()
+        self._init(locals())
+
+    def _get_resize(self, image: np.ndarray, scale: float) -> Transform:
+        in_size = image.shape[:2]
+
+        # Scale the (height, width) target box by the given factor.
+        scaled_target = np.multiply((self.target_height, self.target_width), scale)
+
+        # The image has to fit inside the scaled box while keeping its aspect
+        # ratio, so the smaller of the two per-axis ratios is used for both axes.
+        ratio_h = scaled_target[0] / in_size[0]
+        ratio_w = scaled_target[1] / in_size[1]
+        out_size = np.round(np.multiply(in_size, np.minimum(ratio_h, ratio_w))).astype(int)
+
+        # old (h, w) first, then the new (h, w)
+        in_h, in_w = in_size[0], in_size[1]
+        out_h, out_w = out_size[0], out_size[1]
+        return ResizeTransform(in_h, in_w, out_h, out_w, self.interp)
+
+    def get_transform(self, image: np.ndarray) -> Transform:
+        # one scale factor between `min_scale` and `max_scale` is drawn per call
+        return self._get_resize(image, np.random.uniform(self.min_scale, self.max_scale))
+
+
+class RandomRotation(Augmentation):
+    """
+    This method returns a copy of this image, rotated the given
+    number of degrees counter clockwise around the given center.
+    """
+
+    def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None):
+        """
+        Args:
+            angle (list[float]): If ``sample_style=="range"``,
+                a [min, max] interval from which to sample the angle (in degrees).
+                If ``sample_style=="choice"``, a list of angles to sample from
+            expand (bool): choose if the image should be resized to fit the whole
+                rotated image (default), or simply cropped
+            center (list[[float, float]]): If ``sample_style=="range"``,
+                a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center,
+                [0, 0] being the top left of the image and [1, 1] the bottom right.
+                If ``sample_style=="choice"``, a list of centers to sample from
+                Default: None, which means that the center of rotation is the center of the image
+                center has no effect if expand=True because it only affects shifting
+        """
+        super().__init__()
+        assert sample_style in ["range", "choice"], sample_style
+        self.is_range = sample_style == "range"
+        if isinstance(angle, (float, int)):
+            angle = (angle, angle)
+        if center is not None and isinstance(center[0], (float, int)):
+            center = (center, center)
+        self._init(locals())
+
+    def get_transform(self, image):
+        h, w = image.shape[:2]
+        center = None
+        if self.is_range:
+            angle = np.random.uniform(self.angle[0], self.angle[1])
+            if self.center is not None:
+                center = (
+                    np.random.uniform(self.center[0][0], self.center[1][0]),
+                    np.random.uniform(self.center[0][1], self.center[1][1]),
+                )
+        else:
+            angle = np.random.choice(self.angle)
+            if self.center is not None:
+                # `np.random.choice` only accepts 1-D input and rejects a list of
+                # (x, y) pairs, so pick one of the centers by a random index instead.
+                center = self.center[np.random.randint(len(self.center))]
+        if center is not None:
+            center = (w * center[0], h * center[1])  # Convert to absolute coordinates
+        if angle % 360 == 0:
+            return NoOpTransform()
+
+        return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp)
+
+
+class FixedSizeCrop(Augmentation):
+    """
+    If `crop_size` is smaller than the input image size, then it uses a random crop of
+    the crop size. If `crop_size` is larger than the input image size, then it pads
+    the right and the bottom of the image to the crop size if `pad` is True, otherwise
+    it returns the smaller image.
+    """
+
+    def __init__(
+        self,
+        crop_size: Tuple[int],
+        pad: bool = True,
+        pad_value: float = 128.0,
+        seg_pad_value: int = 255,
+    ):
+        """
+        Args:
+            crop_size: target image (height, width).
+            pad: if True, will pad images smaller than `crop_size` up to `crop_size`
+            pad_value: the padding value to the image.
+            seg_pad_value: the padding value to the segmentation mask.
+        """
+        super().__init__()
+        self._init(locals())
+
+    def _get_crop(self, image: np.ndarray) -> Transform:
+        # (height, width) of the input image and of the requested crop
+        in_size = image.shape[:2]
+        out_size = self.crop_size
+
+        # Room left for a random offset: only where the image exceeds the crop size.
+        slack = np.maximum(np.subtract(in_size, out_size), 0)
+        # A single uniform draw scales the slack of both axes.
+        offset = np.round(np.multiply(slack, np.random.uniform(0.0, 1.0))).astype(int)
+
+        # CropTransform is given x before y, hence the swapped indices.
+        top, left = offset[0], offset[1]
+        return CropTransform(left, top, out_size[1], out_size[0], in_size[1], in_size[0])
+
+    def _get_pad(self, image: np.ndarray) -> Transform:
+        in_size = image.shape[:2]
+        out_size = self.crop_size
+
+        # Padding is needed only on the axes where the image is smaller than the crop size.
+        pad = np.maximum(np.subtract(out_size, in_size), 0)
+        # size of the image once it has been cropped: never larger than either size
+        kept = np.minimum(in_size, out_size)
+        # nothing is added on the first two sides, the whole padding goes to the other two
+        return PadTransform(
+            0,
+            0,
+            pad[1],
+            pad[0],
+            kept[1],
+            kept[0],
+            self.pad_value,
+            self.seg_pad_value,
+        )
+
+    def get_transform(self, image: np.ndarray) -> TransformList:
+        # crop and pad are both derived from the original image; the crop runs first
+        crop = [self._get_crop(image)]
+        pad = [self._get_pad(image)] if self.pad else []
+        return TransformList(crop + pad)
+
+
+class RandomCrop(Augmentation):
+    """
+    Randomly crop a rectangle region out of an image.
+    """
+
+    def __init__(self, crop_type: str, crop_size):
+        """
+        Args:
+            crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range".
+            crop_size (tuple[float, float]): two floats, explained below.
+
+        - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of
+          size (H, W). crop size should be in (0, 1]
+        - "relative_range": uniformly sample two values from [crop_size[0], 1]
+          and [crop_size[1], 1], and use them as in "relative" crop type.
+        - "absolute" crop a (crop_size[0], crop_size[1]) region from input image.
+          crop_size must be smaller than the input image size.
+        - "absolute_range", for an input of size (H, W), uniformly sample H_crop in
+          [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])].
+          Then crop a region (H_crop, W_crop).
+        """
+        # TODO style of relative_range and absolute_range are not consistent:
+        # one takes (h, w) but another takes (min, max)
+        super().__init__()
+        assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"]
+        self._init(locals())
+
+    def get_transform(self, image):
+        h, w = image.shape[:2]
+        ch, cw = self.get_crop_size((h, w))
+        assert h >= ch and w >= cw, "Shape computation in {} has bugs.".format(self)
+        top = np.random.randint(h - ch + 1)  # drawn before the left edge
+        left = np.random.randint(w - cw + 1)
+        return CropTransform(left, top, cw, ch)
+
+    def get_crop_size(self, image_size):
+        """
+        Args:
+            image_size (tuple): height, width
+
+        Returns:
+            crop_size (tuple): height, width in absolute pixels
+        """
+        h, w = image_size
+        if self.crop_type == "relative":
+            frac_h, frac_w = self.crop_size
+            return int(h * frac_h + 0.5), int(w * frac_w + 0.5)
+        if self.crop_type == "relative_range":
+            lower = np.asarray(self.crop_size, dtype=np.float32)
+            frac_h, frac_w = lower + np.random.rand(2) * (1 - lower)
+            return int(h * frac_h + 0.5), int(w * frac_w + 0.5)
+        if self.crop_type == "absolute":
+            return (min(self.crop_size[0], h), min(self.crop_size[1], w))
+        if self.crop_type == "absolute_range":
+            assert self.crop_size[0] <= self.crop_size[1]
+            lo, hi = self.crop_size[0], self.crop_size[1]
+            ch = np.random.randint(min(h, lo), min(h, hi) + 1)
+            cw = np.random.randint(min(w, lo), min(w, hi) + 1)
+            return ch, cw
+        raise NotImplementedError("Unknown crop type {}".format(self.crop_type))
+
+
+class RandomCrop_CategoryAreaConstraint(Augmentation):
+    """
+    Similar to :class:`RandomCrop`, but looks for a cropping window in which no single
+    category takes up more than `single_category_max_area` of the semantic segmentation
+    ground truth, since such windows can make training unstable. At most 10 windows
+    are tried.
+    """
+
+    def __init__(
+        self,
+        crop_type: str,
+        crop_size,
+        single_category_max_area: float = 1.0,
+        ignored_category: int = None,
+    ):
+        """
+        Args:
+            crop_type, crop_size: same as in :class:`RandomCrop`
+            single_category_max_area: the maximum allowed area ratio of a
+                category. Set to 1.0 to disable
+            ignored_category: allow this category in the semantic segmentation
+                ground truth to exceed the area ratio. Usually set to the category
+                that's ignored in training.
+        """
+        self.crop_aug = RandomCrop(crop_type, crop_size)
+        self._init(locals())
+
+    def get_transform(self, image, sem_seg):
+        if self.single_category_max_area >= 1.0:
+            return self.crop_aug.get_transform(image)
+        h, w = sem_seg.shape
+        for _ in range(10):
+            crop_h, crop_w = self.crop_aug.get_crop_size((h, w))
+            y0 = np.random.randint(h - crop_h + 1)
+            x0 = np.random.randint(w - crop_w + 1)
+            window = sem_seg[y0 : y0 + crop_h, x0 : x0 + crop_w]
+            labels, cnt = np.unique(window, return_counts=True)
+            if self.ignored_category is not None:
+                cnt = cnt[labels != self.ignored_category]
+            # accept the window once its largest category stays below the allowed share
+            if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area:
+                break
+        # when no window was accepted, the last sampled one is used as it is
+        return CropTransform(x0, y0, crop_w, crop_h)
+
+
+class RandomExtent(Augmentation):
+    """
+    Outputs an image by cropping a random "subrect" of the source image.
+
+    The subrect can be parameterized to include pixels outside the source image,
+    in which case they will be set to zeros (i.e. black). The size of the output
+    image will vary with the size of the random subrect.
+    """
+
+    def __init__(self, scale_range, shift_range):
+        """
+        Args:
+            scale_range (l, h): Range of input-to-output size scaling factor;
+                one factor is sampled uniformly from it on every call
+            shift_range (x, y): Range of shifts of the cropped subrect. The rect
+                is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)],
+                where (w, h) is the (width, height) of the input image. Set each
+                component to zero to crop at the image's center.
+        """
+        super().__init__()
+        self._init(locals())
+
+    def get_transform(self, image):
+        img_h, img_w = image.shape[:2]
+
+        # Start from the whole image as an XYXY rect centered on the origin.
+        rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h])
+
+        # Scale the rect about the origin by one random factor.
+        rect *= np.random.uniform(self.scale_range[0], self.scale_range[1])
+
+        # Shift it randomly: x coordinates (even slots) first, then y (odd slots).
+        rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5)
+        rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5)
+
+        # Move the origin from the image center to the image corner.
+        rect[0::2] += 0.5 * img_w
+        rect[1::2] += 0.5 * img_h
+
+        x0, y0, x1, y1 = rect
+        return ExtentTransform(
+            src_rect=(x0, y0, x1, y1), output_size=(int(y1 - y0), int(x1 - x0))
+        )
+
+
+class RandomContrast(Augmentation):
+    """
+    Randomly transforms image contrast (blends the image with its mean via BlendTransform).
+
+    Contrast intensity is uniformly sampled in (intensity_min, intensity_max).
+    - intensity < 1 will reduce contrast
+    - intensity = 1 will preserve the input image
+    - intensity > 1 will increase contrast
+
+    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
+    """
+
+    def __init__(self, intensity_min, intensity_max):
+        """
+        Args:
+            intensity_min (float): Minimum augmentation (1 preserves input).
+            intensity_max (float): Maximum augmentation (1 preserves input).
+        """
+        super().__init__()
+        self._init(locals())
+
+    def get_transform(self, image):
+        w = np.random.uniform(self.intensity_min, self.intensity_max)
+        return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w)
+
+
+class RandomBrightness(Augmentation):
+    """
+    Randomly transforms image brightness (blends the image with zeros via BlendTransform).
+
+    Brightness intensity is uniformly sampled in (intensity_min, intensity_max).
+    - intensity < 1 will reduce brightness
+    - intensity = 1 will preserve the input image
+    - intensity > 1 will increase brightness
+
+    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
+    """
+
+    def __init__(self, intensity_min, intensity_max):
+        """
+        Args:
+            intensity_min (float): Minimum augmentation (1 preserves input).
+            intensity_max (float): Maximum augmentation (1 preserves input).
+        """
+        super().__init__()
+        self._init(locals())
+
+    def get_transform(self, image):
+        w = np.random.uniform(self.intensity_min, self.intensity_max)
+        return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w)
+
+
+class RandomSaturation(Augmentation):
+    """
+    Randomly transforms saturation of an RGB image.
+    Input images are assumed to have 'RGB' channel order (the grayscale weights depend on it).
+
+    Saturation intensity is uniformly sampled in (intensity_min, intensity_max).
+    - intensity < 1 will reduce saturation (make the image more grayscale)
+    - intensity = 1 will preserve the input image
+    - intensity > 1 will increase saturation
+
+    See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html
+    """
+
+    def __init__(self, intensity_min, intensity_max):
+        """
+        Args:
+            intensity_min (float): Minimum augmentation (1 preserves input).
+            intensity_max (float): Maximum augmentation (1 preserves input).
+        """
+        super().__init__()
+        self._init(locals())
+
+    def get_transform(self, image):
+        assert image.shape[-1] == 3, "RandomSaturation only works on RGB images"
+        w = np.random.uniform(self.intensity_min, self.intensity_max)
+        grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis]
+        return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
+
+
+class RandomLighting(Augmentation):
+    """
+    The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet.
+    Input images are assumed to have 'RGB' channel order.
+
+    The degree of color jittering is randomly sampled via a normal distribution,
+    with standard deviation given by the scale parameter.
+    """
+
+    def __init__(self, scale):
+        """
+        Args:
+            scale (float): Standard deviation of the normal distribution the three weights use.
+        """
+        super().__init__()
+        self._init(locals())
+        self.eigen_vecs = np.array(
+            [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
+        )
+        self.eigen_vals = np.array([0.2175, 0.0188, 0.0045])  # scales the columns of eigen_vecs
+
+    def get_transform(self, image):
+        assert image.shape[-1] == 3, "RandomLighting only works on RGB images"
+        weights = np.random.normal(scale=self.scale, size=3)  # one weight per component
+        return BlendTransform(
+            src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0
+        )
+
+
+class RandomResize(Augmentation):
+    """Randomly resize image to a target size picked uniformly from shape_list"""
+
+    def __init__(self, shape_list, interp=Image.BILINEAR):
+        """
+        Args:
+            shape_list: a list of candidate target shapes in (h, w)
+            interp: PIL interpolation method, passed on to ResizeTransform
+        """
+        self.shape_list = shape_list
+        self._init(locals())
+
+    def get_transform(self, image):
+        shape_idx = np.random.randint(low=0, high=len(self.shape_list))
+        h, w = self.shape_list[shape_idx]
+        return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp)
+
+
+class MinIoURandomCrop(Augmentation):
+    """Random crop the image & bboxes, the cropped patches have minimum IoU
+    requirement with original image & bboxes, the IoU threshold is randomly
+    selected from min_ious.
+
+    Args:
+        min_ious (tuple): minimum IoU threshold for all intersections with
+            bounding boxes
+        min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
+            where a >= min_crop_size)
+        mode_trials: number of trials for sampling min_ious threshold
+        crop_trials: number of trials for sampling crop_size after cropping
+    """
+
+    def __init__(
+        self,
+        min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+        min_crop_size=0.3,
+        mode_trials=1000,
+        crop_trials=50,
+    ):
+        self.min_ious = min_ious
+        self.sample_mode = (1, *min_ious, 0)
+        self.min_crop_size = min_crop_size
+        self.mode_trials = mode_trials
+        self.crop_trials = crop_trials
+
+    def get_transform(self, image, boxes):
+        """Sample a crop that satisfies a randomly chosen minimum-IoU constraint.
+
+        Args:
+            image: array whose shape unpacks as (h, w, c)
+            boxes: ground truth boxes in (x1, y1, x2, y2) format
+        """
+        if boxes is None:
+            return NoOpTransform()
+        h, w, c = image.shape
+        for _ in range(self.mode_trials):
+            mode = random.choice(self.sample_mode)
+            self.mode = mode
+            if mode == 1:
+                return NoOpTransform()
+            # Any other mode is used as the minimum IoU threshold.
+            min_iou = mode
+            for _ in range(self.crop_trials):
+                new_w = random.uniform(self.min_crop_size * w, w)
+                new_h = random.uniform(self.min_crop_size * h, h)
+
+                # h / w in [0.5, 2]
+                if new_h / new_w < 0.5 or new_h / new_w > 2:
+                    continue
+                # Top-left corner, sampled from 0 so that the crop stays inside the image.
+                left = random.uniform(0, w - new_w)
+                top = random.uniform(0, h - new_h)
+
+                patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h)))
+                # Line or point crop is not allowed
+                if patch[2] == patch[0] or patch[3] == patch[1]:
+                    continue
+                overlaps = pairwise_iou(
+                    Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4))
+                ).reshape(-1)
+                if len(overlaps) > 0 and overlaps.min() < min_iou:
+                    continue
+
+                # center of boxes should inside the crop img
+                # only adjust boxes and instance masks when the gt is not empty
+                if len(overlaps) > 0:
+                    # adjust boxes
+                    def is_center_of_bboxes_in_patch(boxes, patch):
+                        center = (boxes[:, :2] + boxes[:, 2:]) / 2
+                        mask = (
+                            (center[:, 0] > patch[0])
+                            * (center[:, 1] > patch[1])
+                            * (center[:, 0] < patch[2])
+                            * (center[:, 1] < patch[3])
+                        )
+                        return mask
+
+                    mask = is_center_of_bboxes_in_patch(boxes, patch)
+                    if not mask.any():
+                        continue
+                return CropTransform(int(left), int(top), int(new_w), int(new_h))
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/transform.py b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..de44b991d7ab0d920ffb769e1402f08e358d37f7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/data/transforms/transform.py
@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+See "Data Augmentation" tutorial for an overview of the system:
+https://detectron2.readthedocs.io/tutorials/augmentation.html
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from fvcore.transforms.transform import (
+ CropTransform,
+ HFlipTransform,
+ NoOpTransform,
+ Transform,
+ TransformList,
+)
+from PIL import Image
+
+try:
+ import cv2 # noqa
+except ImportError:
+ # OpenCV is an optional dependency at the moment
+ pass
+
+__all__ = [
+ "ExtentTransform",
+ "ResizeTransform",
+ "RotationTransform",
+ "ColorTransform",
+ "PILColorTransform",
+]
+
+
+class ExtentTransform(Transform):
+    """
+    Extracts a subregion from the source image and scales it to the output size.
+
+    The fill color is used to map pixels from the source rect that fall outside
+    the source image.
+
+    See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform
+    """
+
+    def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0):
+        """
+        Args:
+            src_rect (x0, y0, x1, y1): src coordinates
+            output_size (h, w): dst image size
+            interp: PIL interpolation method, defaults to bilinear (`Image.LINEAR` is gone in Pillow 10)
+            fill: Fill color used when src_rect extends outside image
+        """
+        super().__init__()
+        self._set_attributes(locals())
+
+    def apply_image(self, img, interp=None):
+        h, w = self.output_size
+        if len(img.shape) > 2 and img.shape[2] == 1:
+            pil_image = Image.fromarray(img[:, :, 0], mode="L")
+        else:
+            pil_image = Image.fromarray(img)
+        pil_image = pil_image.transform(
+            size=(w, h),
+            method=Image.EXTENT,
+            data=self.src_rect,
+            resample=interp if interp is not None else self.interp,  # NEAREST is 0 (falsy)
+            fill=self.fill,
+        )
+        ret = np.asarray(pil_image)
+        if len(img.shape) > 2 and img.shape[2] == 1:
+            ret = np.expand_dims(ret, -1)
+        return ret
+
+    def apply_coords(self, coords):
+        # Transform image center from source coordinates into output coordinates
+        # and then map the new origin to the corner of the output image.
+        h, w = self.output_size
+        x0, y0, x1, y1 = self.src_rect
+        new_coords = coords.astype(np.float32)
+        new_coords[:, 0] -= 0.5 * (x0 + x1)
+        new_coords[:, 1] -= 0.5 * (y0 + y1)
+        new_coords[:, 0] *= w / (x1 - x0)
+        new_coords[:, 1] *= h / (y1 - y0)
+        new_coords[:, 0] += 0.5 * w
+        new_coords[:, 1] += 0.5 * h
+        return new_coords
+
+    def apply_segmentation(self, segmentation):
+        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
+        return segmentation
+
+
+class ResizeTransform(Transform):
+    """
+    Resize the image to a target size.
+    """
+
+    def __init__(self, h, w, new_h, new_w, interp=None):
+        """
+        Args:
+            h, w (int): original image size
+            new_h, new_w (int): new image size
+            interp: PIL interpolation methods, defaults to bilinear.
+        """
+        # TODO decide on PIL vs opencv
+        super().__init__()
+        if interp is None:
+            interp = Image.BILINEAR
+        self._set_attributes(locals())
+
+    def apply_image(self, img, interp=None):
+        assert img.shape[:2] == (self.h, self.w)
+        assert len(img.shape) <= 4
+        interp_method = interp if interp is not None else self.interp
+        # uint8 images are resized with PIL, every other dtype with torch (below).
+        if img.dtype == np.uint8:
+            if len(img.shape) > 2 and img.shape[2] == 1:
+                pil_image = Image.fromarray(img[:, :, 0], mode="L")
+            else:
+                pil_image = Image.fromarray(img)
+            pil_image = pil_image.resize((self.new_w, self.new_h), interp_method)
+            ret = np.asarray(pil_image)
+            if len(img.shape) > 2 and img.shape[2] == 1:
+                ret = np.expand_dims(ret, -1)
+        else:
+            # PIL only supports uint8, so this branch goes through F.interpolate instead.
+            if any(x < 0 for x in img.strides):
+                img = np.ascontiguousarray(img)  # hand torch a copy without negative strides
+            img = torch.from_numpy(img)
+            shape = list(img.shape)
+            shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
+            img = img.view(shape_4d).permute(2, 3, 0, 1)  # hw(c) -> nchw
+            _PIL_RESIZE_TO_INTERPOLATE_MODE = {
+                Image.NEAREST: "nearest",
+                Image.BILINEAR: "bilinear",
+                Image.BICUBIC: "bicubic",
+            }
+            mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method]
+            align_corners = None if mode == "nearest" else False
+            img = F.interpolate(
+                img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners
+            )
+            shape[:2] = (self.new_h, self.new_w)
+            ret = img.permute(2, 3, 0, 1).view(shape).numpy()  # nchw -> hw(c)
+
+        return ret
+
+    def apply_coords(self, coords):
+        coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w)
+        coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h)
+        return coords  # NOTE: the input array is scaled in place and also returned
+
+    def apply_segmentation(self, segmentation):
+        segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
+        return segmentation
+
+    def inverse(self):
+        return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp)
+
+
+class RotationTransform(Transform):
+    """
+    This transform returns a copy of the image, rotated the given
+    number of degrees counter clockwise around its center.
+    """
+
+    def __init__(self, h, w, angle, expand=True, center=None, interp=None):
+        """
+        Args:
+            h, w (int): original image size
+            angle (float): degrees for rotation
+            expand (bool): choose if the image should be resized to fit the whole
+                rotated image (default), or simply cropped
+            center (tuple (width, height)): coordinates of the rotation center
+                if left to None, the center will be fit to the center of each image
+                center has no effect if expand=True because it only affects shifting
+            interp: cv2 interpolation method, default cv2.INTER_LINEAR
+        """
+        super().__init__()
+        image_center = np.array((w / 2, h / 2))
+        if center is None:
+            center = image_center
+        if interp is None:
+            interp = cv2.INTER_LINEAR
+        abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle))))
+        if expand:
+            # find the new width and height bounds
+            bound_w, bound_h = np.rint(
+                [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin]
+            ).astype(int)
+        else:
+            bound_w, bound_h = w, h
+        # The locals above are read back later as self.center, self.bound_w, self.image_center, ...
+        self._set_attributes(locals())
+        self.rm_coords = self.create_rotation_matrix()  # applied to point coordinates
+        # Needed because of this problem https://github.com/opencv/opencv/issues/11784
+        self.rm_image = self.create_rotation_matrix(offset=-0.5)  # applied to images
+
+    def apply_image(self, img, interp=None):
+        """
+        img should be a numpy array, formatted as Height * Width * Nchannels
+        """
+        if len(img) == 0 or self.angle % 360 == 0:
+            return img
+        assert img.shape[:2] == (self.h, self.w)
+        interp = interp if interp is not None else self.interp
+        return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp)
+
+    def apply_coords(self, coords):
+        """
+        coords should be a N * 2 array-like, containing N couples of (x, y) points
+        """
+        coords = np.asarray(coords, dtype=float)
+        if len(coords) == 0 or self.angle % 360 == 0:
+            return coords
+        return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :]
+
+    def apply_segmentation(self, segmentation):
+        segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST)
+        return segmentation
+
+    def create_rotation_matrix(self, offset=0):
+        center = (self.center[0] + offset, self.center[1] + offset)
+        rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1)
+        if self.expand:
+            # Find the coordinates of the center of rotation in the new image
+            # The only point for which we know the future coordinates is the center of the image
+            rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :]
+            new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center
+            # shift the rotation center to the new coordinates
+            rm[:, 2] += new_center
+        return rm
+
+    def inverse(self):
+        """
+        The inverse is to rotate it back with expand, and crop to get the original shape.
+        """
+        if not self.expand:  # Not possible to inverse if a part of the image is lost
+            raise NotImplementedError()
+        rotation = RotationTransform(
+            self.bound_h, self.bound_w, -self.angle, True, None, self.interp
+        )
+        crop = CropTransform(
+            (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h
+        )
+        return TransformList([rotation, crop])
+
+
+class ColorTransform(Transform):
+    """
+    Generic wrapper for any photometric transforms.
+    These transformations should only affect the color space and
+    not the coordinate space of the image (e.g. annotation
+    coordinates such as bounding boxes should not be changed)
+    """
+
+    def __init__(self, op):
+        """
+        Args:
+            op (Callable): operation applied to the image; takes in an ndarray and
+                returns an ndarray. A ValueError is raised if it is not callable.
+        """
+        if not callable(op):
+            raise ValueError("op parameter should be callable")
+        super().__init__()
+        self._set_attributes(locals())
+
+    def apply_image(self, img):
+        return self.op(img)
+
+    def apply_coords(self, coords):
+        return coords  # coordinates are returned untouched
+
+    def inverse(self):
+        return NoOpTransform()  # the color change itself is not reverted
+
+    def apply_segmentation(self, segmentation):
+        return segmentation  # segmentation maps are returned untouched
+
+
+class PILColorTransform(ColorTransform):
+    """
+    Generic wrapper for PIL Photometric image transforms,
+    which affect the color space and not the coordinate
+    space of the image
+    """
+
+    def __init__(self, op):
+        """
+        Args:
+            op (Callable): operation to be applied to the image,
+                which takes in a PIL Image and returns a transformed
+                PIL Image.
+                For reference on possible operations see:
+                - https://pillow.readthedocs.io/en/stable/
+        """
+        if not callable(op):
+            raise ValueError("op parameter should be callable")
+        super().__init__(op)
+
+    def apply_image(self, img):
+        img = Image.fromarray(img)  # ndarray -> PIL Image, the type ``op`` expects
+        return np.asarray(super().apply_image(img))  # result of ``op`` back to an ndarray
+
+
+def HFlip_rotated_box(transform, rotated_boxes):
+    """
+    Mirror rotated boxes horizontally; the array is updated in place and returned.
+
+    Args:
+        transform: the flip transform; only its ``width`` attribute is read.
+        rotated_boxes (ndarray): Nx5 floating point array of
+            (x_center, y_center, width, height, angle_degrees) format
+            in absolute coordinates.
+    """
+    # Negate the angle column, then reflect x_center about the image width.
+    rotated_boxes[:, 4] = -rotated_boxes[:, 4]
+    rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0]
+    return rotated_boxes
+
+
+def Resize_rotated_box(transform, rotated_boxes):
+    """
+    Scale rotated boxes to follow a resize; the array is updated in place and returned.
+    The width/height/angle formulas are approximations, see :meth:`RotatedBoxes.scale`.
+
+    Args:
+        transform: the resize transform; its ``w``, ``h``, ``new_w`` and ``new_h`` are read.
+        rotated_boxes (ndarray): Nx5 floating point array of
+            (x_center, y_center, width, height, angle_degrees) in absolute coordinates.
+    """
+    sx = transform.new_w * 1.0 / transform.w
+    sy = transform.new_h * 1.0 / transform.h
+    # The angle column is read (into a new array) before any column is overwritten.
+    angle_rad = rotated_boxes[:, 4] * np.pi / 180.0
+    cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
+    rotated_boxes[:, 0] *= sx
+    rotated_boxes[:, 1] *= sy
+    rotated_boxes[:, 2] *= np.sqrt(np.square(sx * cos_a) + np.square(sy * sin_a))
+    rotated_boxes[:, 3] *= np.sqrt(np.square(sx * sin_a) + np.square(sy * cos_a))
+    rotated_boxes[:, 4] = np.arctan2(sx * sin_a, sy * cos_a) * 180 / np.pi
+
+    return rotated_boxes
+
+
+HFlipTransform.register_type("rotated_box", HFlip_rotated_box)
+ResizeTransform.register_type("rotated_box", Resize_rotated_box)
+
+# not necessary any more with latest fvcore
+NoOpTransform.register_type("rotated_box", lambda t, x: x)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/engine/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..08a61572b4c7d09c8d400e903a96cbf5b2cc4763
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from .launch import *
+from .train_loop import *
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
+
+
+# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
+# but still make them available here
+from .hooks import *
+from .defaults import *
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/engine/defaults.py b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d49148ca7b048402a63490bf7df83a43c65d9f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/defaults.py
@@ -0,0 +1,715 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+This file contains components with some default boilerplate logic user may need
+in training / testing. They will not work for everyone, but many users may find them useful.
+
+The behavior of functions/classes in this file is subject to change,
+since they are meant to represent the "common default behavior" people need in their projects.
+"""
+
+import argparse
+import logging
+import os
+import sys
+import weakref
+from collections import OrderedDict
+from typing import Optional
+import torch
+from fvcore.nn.precise_bn import get_bn_modules
+from omegaconf import OmegaConf
+from torch.nn.parallel import DistributedDataParallel
+
+import annotator.oneformer.detectron2.data.transforms as T
+from annotator.oneformer.detectron2.checkpoint import DetectionCheckpointer
+from annotator.oneformer.detectron2.config import CfgNode, LazyConfig
+from annotator.oneformer.detectron2.data import (
+ MetadataCatalog,
+ build_detection_test_loader,
+ build_detection_train_loader,
+)
+from annotator.oneformer.detectron2.evaluation import (
+ DatasetEvaluator,
+ inference_on_dataset,
+ print_csv_format,
+ verify_results,
+)
+from annotator.oneformer.detectron2.modeling import build_model
+from annotator.oneformer.detectron2.solver import build_lr_scheduler, build_optimizer
+from annotator.oneformer.detectron2.utils import comm
+from annotator.oneformer.detectron2.utils.collect_env import collect_env_info
+from annotator.oneformer.detectron2.utils.env import seed_all_rng
+from annotator.oneformer.detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+from annotator.oneformer.detectron2.utils.logger import setup_logger
+
+from . import hooks
+from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase
+
+__all__ = [
+ "create_ddp_model",
+ "default_argument_parser",
+ "default_setup",
+ "default_writers",
+ "DefaultPredictor",
+ "DefaultTrainer",
+]
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+    """
+    Wrap the model in DistributedDataParallel if there are >1 processes, else return it as is.
+
+    Args:
+        model: a torch.nn.Module
+        fp16_compression: add fp16 compression hooks to the ddp object.
+            See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+        kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
+    """  # noqa
+    if comm.get_world_size() == 1:
+        return model
+    if "device_ids" not in kwargs:
+        kwargs["device_ids"] = [comm.get_local_rank()]  # default when the caller gave none
+    ddp = DistributedDataParallel(model, **kwargs)
+    if fp16_compression:
+        from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+        ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+    return ddp
+
+
+def default_argument_parser(epilog=None):
+    """
+    Create a parser with some common arguments used by detectron2 users.
+
+    Args:
+        epilog (str): epilog passed to ArgumentParser; if empty or None, built-in usage examples are used.
+
+    Returns:
+        argparse.ArgumentParser:
+    """
+    parser = argparse.ArgumentParser(
+        epilog=epilog
+        or f"""
+Examples:
+
+Run on single machine:
+ $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
+
+Change some config options:
+ $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
+
+Run on multiple machines:
+ (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
+ (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
+""",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+    )
+    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Whether to attempt to resume from the checkpoint directory. "
+        "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
+    )
+    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
+    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
+    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
+    parser.add_argument(
+        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
+    )
+
+    # PyTorch still may leave orphan processes in multi-gpu training. Therefore the
+    # port is derived deterministically from the uid (range 49152-65535), so that
+    # users are aware of orphan processes by seeing the port occupied.
+    port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
+    parser.add_argument(
+        "--dist-url",
+        default="tcp://127.0.0.1:{}".format(port),
+        help="initialization URL for pytorch distributed backend. See "
+        "https://pytorch.org/docs/stable/distributed.html for details.",
+    )
+    parser.add_argument(
+        "opts",
+        help="""
+Modify config options at the end of the command. For Yacs configs, use
+space-separated "PATH.KEY VALUE" pairs.
+For python-based LazyConfig, use "path.key=value".
+ """.strip(),
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+def _try_get_key(cfg, *keys, default=None):
+    """
+    Return the value of the first entry of ``keys`` that exists in ``cfg``, else ``default``.
+    """
+    missing = object()  # unique sentinel marking "key not found"
+    if isinstance(cfg, CfgNode):
+        cfg = OmegaConf.create(cfg.dump())
+    for key in keys:
+        value = OmegaConf.select(cfg, key, default=missing)
+        if value is not missing:
+            return value
+    return default
+
+
+def _highlight(code, filename):
+    """Colorize ``code`` for a 256-color terminal; returned unchanged without pygments."""
+    try:
+        import pygments
+    except ImportError:
+        return code
+    from pygments.formatters import Terminal256Formatter
+    from pygments.lexers import Python3Lexer, YamlLexer
+
+    # Anything that is not a ``.py`` file is highlighted as YAML.
+    lexer_cls = Python3Lexer if filename.endswith(".py") else YamlLexer
+    return pygments.highlight(code, lexer_cls(), Terminal256Formatter(style="monokai"))
+
+
+def default_setup(cfg, args):
+    """
+    Perform some basic common setups at the beginning of a job, including:
+
+    1. Set up the detectron2 logger
+    2. Log basic information about environment, cmdline arguments, and config
+    3. Backup the config to the output directory
+    4. Seed the RNGs (per rank) and set the cudnn benchmark flag unless in eval-only mode
+
+    Args:
+        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
+        args (argparse.NameSpace): the command line arguments to be logged
+    """
+    output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
+    if comm.is_main_process() and output_dir:
+        PathManager.mkdirs(output_dir)
+
+    rank = comm.get_rank()
+    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
+    logger = setup_logger(output_dir, distributed_rank=rank)
+    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
+    logger.info("Environment info:\n" + collect_env_info())
+
+    logger.info("Command line arguments: " + str(args))
+    if hasattr(args, "config_file") and args.config_file != "":
+        # Read through a context manager so the file handle is closed right away.
+        with PathManager.open(args.config_file, "r") as f:
+            config_text = _highlight(f.read(), args.config_file)
+        logger.info(
+            "Contents of args.config_file={}:\n{}".format(args.config_file, config_text)
+        )
+
+    if comm.is_main_process() and output_dir:
+        # Note: some of our scripts may expect the existence of
+        # config.yaml in output directory
+        path = os.path.join(output_dir, "config.yaml")
+        if isinstance(cfg, CfgNode):
+            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
+            with PathManager.open(path, "w") as f:
+                f.write(cfg.dump())
+        else:
+            LazyConfig.save(cfg, path)
+        logger.info("Full config saved to {}".format(path))
+
+    # make sure each worker has a different, yet deterministic seed if specified
+    seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
+    seed_all_rng(None if seed < 0 else seed + rank)
+
+    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
+    # typical validation set.
+    if not (hasattr(args, "eval_only") and args.eval_only):
+        torch.backends.cudnn.benchmark = _try_get_key(
+            cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
+        )
+
+
+def default_writers(output_dir: str, max_iter: Optional[int] = None):
+    """
+    Build the default list of :class:`EventWriter` objects: a
+    :class:`CommonMetricPrinter`, a :class:`JSONWriter` and a
+    :class:`TensorboardXWriter`. ``PathManager.mkdirs(output_dir)`` is called first.
+
+    Args:
+        output_dir: directory to store JSON metrics and tensorboard events
+        max_iter: the total number of iterations
+
+    Returns:
+        list[EventWriter]: a list of :class:`EventWriter` objects.
+    """
+    PathManager.mkdirs(output_dir)
+    # Prints "common" metrics only, so it may not always show what you want to see.
+    printer = CommonMetricPrinter(max_iter)
+    json_writer = JSONWriter(os.path.join(output_dir, "metrics.json"))
+    tensorboard_writer = TensorboardXWriter(output_dir)
+    return [printer, json_writer, tensorboard_writer]
+
+
+
+class DefaultPredictor:
+    """
+    Create a simple end-to-end predictor with the given config that runs on
+    single device for a single input image.
+
+    Compared to using the model directly, this class does the following additions:
+
+    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
+    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
+    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
+    4. Take one input image and produce a single output, instead of a batch.
+
+    This is meant for simple demo purposes, so it does the above steps automatically.
+    This is not meant for benchmarks or running complicated inference logic.
+    If you'd like to do anything more complicated, please refer to its source code as
+    examples to build and use the model manually.
+
+    Attributes:
+        metadata (Metadata): the metadata of the underlying dataset, obtained from
+            cfg.DATASETS.TEST[0]; only set when cfg.DATASETS.TEST is non-empty.
+
+    Examples:
+    ::
+        pred = DefaultPredictor(cfg)
+        inputs = cv2.imread("input.jpg")
+        outputs = pred(inputs)
+    """
+
+    def __init__(self, cfg):
+        self.cfg = cfg.clone()  # cfg can be modified by model
+        self.model = build_model(self.cfg)
+        self.model.eval()
+        if len(cfg.DATASETS.TEST):
+            self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0])
+
+        checkpointer = DetectionCheckpointer(self.model)
+        checkpointer.load(cfg.MODEL.WEIGHTS)
+        # MIN_SIZE_TEST is passed as both ends of the range given to ResizeShortestEdge.
+        self.aug = T.ResizeShortestEdge(
+            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
+        )
+
+        self.input_format = cfg.INPUT.FORMAT
+        assert self.input_format in ["RGB", "BGR"], self.input_format
+
+    def __call__(self, original_image):
+        """
+        Args:
+            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+        Returns:
+            predictions (dict):
+                the output of the model for one image only.
+                See :doc:`/tutorials/models` for details about the format.
+        """
+        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
+            # Apply pre-processing to image.
+            if self.input_format == "RGB":
+                # the input is BGR; reverse the channel axis when the model expects RGB
+                original_image = original_image[:, :, ::-1]
+            height, width = original_image.shape[:2]
+            image = self.aug.get_transform(original_image).apply_image(original_image)
+            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+            # The pre-resize height/width are passed along with the resized CHW image.
+            inputs = {"image": image, "height": height, "width": width}
+            predictions = self.model([inputs])[0]
+            return predictions
+
+
+class DefaultTrainer(TrainerBase):
+ """
+ A trainer with default training logic. It does the following:
+
+ 1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
+ defined by the given config. Create a LR scheduler defined by the config.
+ 2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
+ `resume_or_load` is called.
+ 3. Register a few common hooks defined by the config.
+
+ It is created to simplify the **standard model training workflow** and reduce code boilerplate
+ for users who only need the standard training workflow, with standard features.
+ It means this class makes *many assumptions* about your training logic that
+ may easily become invalid in a new research. In fact, any assumptions beyond those made in the
+ :class:`SimpleTrainer` are too much for research.
+
+ The code of this class has been annotated about restrictive assumptions it makes.
+ When they do not work for you, you're encouraged to:
+
+ 1. Overwrite methods of this class, OR:
+ 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
+ nothing else. You can then add your own hooks if needed. OR:
+ 3. Write your own training loop similar to `tools/plain_train_net.py`.
+
+ See the :doc:`/tutorials/training` tutorials for more details.
+
+ Note that the behavior of this class, like other functions/classes in
+ this file, is not stable, since it is meant to represent the "common default behavior".
+ It is only guaranteed to work well with the standard models and training workflow in detectron2.
+ To obtain more stable behavior, write your own training logic with other public APIs.
+
+ Examples:
+ ::
+ trainer = DefaultTrainer(cfg)
+ trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
+ trainer.train()
+
+ Attributes:
+ scheduler:
+ checkpointer (DetectionCheckpointer):
+ cfg (CfgNode):
+ """
+
+ def __init__(self, cfg):
+ """
+ Args:
+ cfg (CfgNode):
+ """
+ super().__init__()
+ logger = logging.getLogger("detectron2")
+ if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
+ setup_logger()
+ cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
+
+ # Assume these objects must be constructed in this order.
+ model = self.build_model(cfg)
+ optimizer = self.build_optimizer(cfg, model)
+ data_loader = self.build_train_loader(cfg)
+
+ model = create_ddp_model(model, broadcast_buffers=False)
+ self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
+ model, data_loader, optimizer
+ )
+
+ self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+ self.checkpointer = DetectionCheckpointer(
+ # Assume you want to save checkpoints together with logs/statistics
+ model,
+ cfg.OUTPUT_DIR,
+ trainer=weakref.proxy(self),
+ )
+ self.start_iter = 0
+ self.max_iter = cfg.SOLVER.MAX_ITER
+ self.cfg = cfg
+
+ self.register_hooks(self.build_hooks())
+
+ def resume_or_load(self, resume=True):
+ """
+ If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
+ a `last_checkpoint` file), resume from the file. Resuming means loading all
+ available states (eg. optimizer and scheduler) and update iteration counter
+ from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
+
+ Otherwise, this is considered as an independent training. The method will load model
+ weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
+ from iteration 0.
+
+ Args:
+ resume (bool): whether to do resume or not
+ """
+ self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
+ if resume and self.checkpointer.has_checkpoint():
+ # The checkpoint stores the training iteration that just finished, thus we start
+ # at the next iteration
+ self.start_iter = self.iter + 1
+
+ def build_hooks(self):
+ """
+ Build a list of default hooks, including timing, evaluation,
+ checkpointing, lr scheduling, precise BN, writing events.
+
+ Returns:
+ list[HookBase]:
+ """
+ cfg = self.cfg.clone()
+ cfg.defrost()
+ cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
+
+ ret = [
+ hooks.IterationTimer(),
+ hooks.LRScheduler(),
+ hooks.PreciseBN(
+ # Run at the same freq as (but before) evaluation.
+ cfg.TEST.EVAL_PERIOD,
+ self.model,
+ # Build a new data loader to not affect training
+ self.build_train_loader(cfg),
+ cfg.TEST.PRECISE_BN.NUM_ITER,
+ )
+ if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
+ else None,
+ ]
+
+ # Do PreciseBN before checkpointer, because it updates the model and need to
+ # be saved by checkpointer.
+ # This is not always the best: if checkpointing has a different frequency,
+ # some checkpoints may have more precise statistics than others.
+ if comm.is_main_process():
+ ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
+
+ def test_and_save_results():
+ self._last_eval_results = self.test(self.cfg, self.model)
+ return self._last_eval_results
+
+ # Do evaluation after checkpointer, because then if it fails,
+ # we can use the saved checkpoint to debug.
+ ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
+
+ if comm.is_main_process():
+ # Here the default print/log frequency of each writer is used.
+ # run writers in the end, so that evaluation metrics are written
+ ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
+ return ret
+
+ def build_writers(self):
+ """
+ Build a list of writers to be used using :func:`default_writers()`.
+ If you'd like a different list of writers, you can overwrite it in
+ your trainer.
+
+ Returns:
+ list[EventWriter]: a list of :class:`EventWriter` objects.
+ """
+ return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
+
+ def train(self):
+ """
+ Run training.
+
+ Returns:
+ OrderedDict of results, if evaluation is enabled. Otherwise None.
+ """
+ super().train(self.start_iter, self.max_iter)
+ if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
+ assert hasattr(
+ self, "_last_eval_results"
+ ), "No evaluation results obtained during training!"
+ verify_results(self.cfg, self._last_eval_results)
+ return self._last_eval_results
+
+ def run_step(self):
+ self._trainer.iter = self.iter
+ self._trainer.run_step()
+
+ def state_dict(self):
+ ret = super().state_dict()
+ ret["_trainer"] = self._trainer.state_dict()
+ return ret
+
+ def load_state_dict(self, state_dict):
+ super().load_state_dict(state_dict)
+ self._trainer.load_state_dict(state_dict["_trainer"])
+
+ @classmethod
+ def build_model(cls, cfg):
+ """
+ Returns:
+ torch.nn.Module:
+
+ It now calls :func:`detectron2.modeling.build_model`.
+ Overwrite it if you'd like a different model.
+ """
+ model = build_model(cfg)
+ logger = logging.getLogger(__name__)
+ logger.info("Model:\n{}".format(model))
+ return model
+
+ @classmethod
+ def build_optimizer(cls, cfg, model):
+ """
+ Returns:
+ torch.optim.Optimizer:
+
+ It now calls :func:`detectron2.solver.build_optimizer`.
+ Overwrite it if you'd like a different optimizer.
+ """
+ return build_optimizer(cfg, model)
+
+ @classmethod
+ def build_lr_scheduler(cls, cfg, optimizer):
+ """
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
+ Overwrite it if you'd like a different scheduler.
+ """
+ return build_lr_scheduler(cfg, optimizer)
+
+ @classmethod
+ def build_train_loader(cls, cfg):
+ """
+ Returns:
+ iterable
+
+ It now calls :func:`detectron2.data.build_detection_train_loader`.
+ Overwrite it if you'd like a different data loader.
+ """
+ return build_detection_train_loader(cfg)
+
+ @classmethod
+ def build_test_loader(cls, cfg, dataset_name):
+ """
+ Returns:
+ iterable
+
+ It now calls :func:`detectron2.data.build_detection_test_loader`.
+ Overwrite it if you'd like a different data loader.
+ """
+ return build_detection_test_loader(cfg, dataset_name)
+
+ @classmethod
+ def build_evaluator(cls, cfg, dataset_name):
+ """
+ Returns:
+ DatasetEvaluator or None
+
+ It is not implemented by default.
+ """
+ raise NotImplementedError(
+ """
+If you want DefaultTrainer to automatically run evaluation,
+please implement `build_evaluator()` in subclasses (see train_net.py for example).
+Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
+"""
+ )
+
+    @classmethod
+    def test(cls, cfg, model, evaluators=None):
+        """
+        Evaluate the given model. The given model is expected to already contain
+        weights to evaluate.
+
+        Args:
+            cfg (CfgNode): config; ``cfg.DATASETS.TEST`` lists the datasets to evaluate.
+            model (nn.Module): the model to run inference with.
+            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
+                :meth:`build_evaluator`. Otherwise, must have the same length as
+                ``cfg.DATASETS.TEST``.
+
+        Returns:
+            dict: metrics keyed by dataset name; with a single dataset, just its metrics.
+        """
+        logger = logging.getLogger(__name__)
+        if isinstance(evaluators, DatasetEvaluator):
+            evaluators = [evaluators]
+        if evaluators is not None:
+            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
+                len(cfg.DATASETS.TEST), len(evaluators)
+            )
+
+        results = OrderedDict()
+        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
+            data_loader = cls.build_test_loader(cfg, dataset_name)
+            # When evaluators are passed in as arguments,
+            # implicitly assume that evaluators can be created before data_loader.
+            if evaluators is not None:
+                evaluator = evaluators[idx]
+            else:
+                try:
+                    evaluator = cls.build_evaluator(cfg, dataset_name)
+                except NotImplementedError:
+                    logger.warning(
+                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
+                        "or implement its `build_evaluator` method."
+                    )
+                    results[dataset_name] = {}
+                    continue
+            results_i = inference_on_dataset(model, data_loader, evaluator)
+            results[dataset_name] = results_i
+            if comm.is_main_process():
+                assert isinstance(
+                    results_i, dict
+                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                    results_i
+                )
+                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+                print_csv_format(results_i)
+
+        if len(results) == 1:
+            results = list(results.values())[0]
+        return results
+
+ @staticmethod
+ def auto_scale_workers(cfg, num_workers: int):
+ """
+ When the config is defined for certain number of workers (according to
+ ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
+ workers currently in use, returns a new cfg where the total batch size
+ is scaled so that the per-GPU batch size stays the same as the
+ original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
+
+ Other config options are also scaled accordingly:
+ * training steps and warmup steps are scaled inverse proportionally.
+ * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
+
+ For example, with the original config like the following:
+
+ .. code-block:: yaml
+
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.1
+ REFERENCE_WORLD_SIZE: 8
+ MAX_ITER: 5000
+ STEPS: (4000,)
+ CHECKPOINT_PERIOD: 1000
+
+ When this config is used on 16 GPUs instead of the reference number 8,
+ calling this method will return a new config with:
+
+ .. code-block:: yaml
+
+ IMS_PER_BATCH: 32
+ BASE_LR: 0.2
+ REFERENCE_WORLD_SIZE: 16
+ MAX_ITER: 2500
+ STEPS: (2000,)
+ CHECKPOINT_PERIOD: 500
+
+ Note that both the original config and this new config can be trained on 16 GPUs.
+ It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
+
+ Returns:
+ CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
+ """
+ old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
+ if old_world_size == 0 or old_world_size == num_workers:
+ return cfg
+ cfg = cfg.clone()
+ frozen = cfg.is_frozen()
+ cfg.defrost()
+
+ assert (
+ cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
+ ), "Invalid REFERENCE_WORLD_SIZE in config!"
+ scale = num_workers / old_world_size
+ bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
+ lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
+ max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
+ warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
+ cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
+ cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
+ cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
+ cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
+ logger = logging.getLogger(__name__)
+ logger.info(
+ f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
+ f"max_iter={max_iter}, warmup={warmup_iter}."
+ )
+
+ if frozen:
+ cfg.freeze()
+ return cfg
+
+
+# Access basic attributes from the underlying trainer
+for _attr in ["model", "data_loader", "optimizer"]:
+ setattr(
+ DefaultTrainer,
+ _attr,
+ property(
+ # getter
+ lambda self, x=_attr: getattr(self._trainer, x),
+ # setter
+ lambda self, value, x=_attr: setattr(self._trainer, x, value),
+ ),
+ )
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/engine/hooks.py b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd43ac77068c908bc13263f1697fa2e3332d7c9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/hooks.py
@@ -0,0 +1,690 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import datetime
+import itertools
+import logging
+import math
+import operator
+import os
+import tempfile
+import time
+import warnings
+from collections import Counter
+import torch
+from fvcore.common.checkpoint import Checkpointer
+from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
+from fvcore.common.param_scheduler import ParamScheduler
+from fvcore.common.timer import Timer
+from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats
+
+import annotator.oneformer.detectron2.utils.comm as comm
+from annotator.oneformer.detectron2.evaluation.testing import flatten_results_dict
+from annotator.oneformer.detectron2.solver import LRMultiplier
+from annotator.oneformer.detectron2.solver import LRScheduler as _LRScheduler
+from annotator.oneformer.detectron2.utils.events import EventStorage, EventWriter
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .train_loop import HookBase
+
+__all__ = [
+ "CallbackHook",
+ "IterationTimer",
+ "PeriodicWriter",
+ "PeriodicCheckpointer",
+ "BestCheckpointer",
+ "LRScheduler",
+ "AutogradProfiler",
+ "EvalHook",
+ "PreciseBN",
+ "TorchProfiler",
+ "TorchMemoryStats",
+]
+
+
+"""
+Implement some common hooks.
+"""
+
+
+class CallbackHook(HookBase):
+ """
+ Create a hook using callback functions provided by the user.
+ """
+
+ def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
+ """
+ Each argument is a function that takes one argument: the trainer.
+ """
+ self._before_train = before_train
+ self._before_step = before_step
+ self._after_step = after_step
+ self._after_train = after_train
+
+ def before_train(self):
+ if self._before_train:
+ self._before_train(self.trainer)
+
+ def after_train(self):
+ if self._after_train:
+ self._after_train(self.trainer)
+ # The functions may be closures that hold reference to the trainer
+ # Therefore, delete them to avoid circular reference.
+ del self._before_train, self._after_train
+ del self._before_step, self._after_step
+
+ def before_step(self):
+ if self._before_step:
+ self._before_step(self.trainer)
+
+ def after_step(self):
+ if self._after_step:
+ self._after_step(self.trainer)
+
+
+class IterationTimer(HookBase):
+ """
+ Track the time spent for each iteration (each run_step call in the trainer).
+ Print a summary in the end of training.
+
+ This hook uses the time between the call to its :meth:`before_step`
+ and :meth:`after_step` methods.
+ Under the convention that :meth:`before_step` of all hooks should only
+ take negligible amount of time, the :class:`IterationTimer` hook should be
+ placed at the beginning of the list of hooks to obtain accurate timing.
+ """
+
+ def __init__(self, warmup_iter=3):
+ """
+ Args:
+ warmup_iter (int): the number of iterations at the beginning to exclude
+ from timing.
+ """
+ self._warmup_iter = warmup_iter
+ self._step_timer = Timer()
+ self._start_time = time.perf_counter()
+ self._total_timer = Timer()
+
+ def before_train(self):
+ self._start_time = time.perf_counter()
+ self._total_timer.reset()
+ self._total_timer.pause()
+
+ def after_train(self):
+ logger = logging.getLogger(__name__)
+ total_time = time.perf_counter() - self._start_time
+ total_time_minus_hooks = self._total_timer.seconds()
+ hook_time = total_time - total_time_minus_hooks
+
+ num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter
+
+ if num_iter > 0 and total_time_minus_hooks > 0:
+ # Speed is meaningful only after warmup
+ # NOTE this format is parsed by grep in some scripts
+ logger.info(
+ "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
+ num_iter,
+ str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
+ total_time_minus_hooks / num_iter,
+ )
+ )
+
+ logger.info(
+ "Total training time: {} ({} on hooks)".format(
+ str(datetime.timedelta(seconds=int(total_time))),
+ str(datetime.timedelta(seconds=int(hook_time))),
+ )
+ )
+
+ def before_step(self):
+ self._step_timer.reset()
+ self._total_timer.resume()
+
+ def after_step(self):
+ # +1 because we're in after_step, the current step is done
+ # but not yet counted
+ iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1
+ if iter_done >= self._warmup_iter:
+ sec = self._step_timer.seconds()
+ self.trainer.storage.put_scalars(time=sec)
+ else:
+ self._start_time = time.perf_counter()
+ self._total_timer.reset()
+
+ self._total_timer.pause()
+
+
+class PeriodicWriter(HookBase):
+ """
+ Write events to EventStorage (by calling ``writer.write()``) periodically.
+
+ It is executed every ``period`` iterations and after the last iteration.
+ Note that ``period`` does not affect how data is smoothed by each writer.
+ """
+
+ def __init__(self, writers, period=20):
+ """
+ Args:
+ writers (list[EventWriter]): a list of EventWriter objects
+ period (int):
+ """
+ self._writers = writers
+ for w in writers:
+ assert isinstance(w, EventWriter), w
+ self._period = period
+
+ def after_step(self):
+ if (self.trainer.iter + 1) % self._period == 0 or (
+ self.trainer.iter == self.trainer.max_iter - 1
+ ):
+ for writer in self._writers:
+ writer.write()
+
+ def after_train(self):
+ for writer in self._writers:
+ # If any new data is found (e.g. produced by other after_train),
+ # write them before closing
+ writer.write()
+ writer.close()
+
+
+class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
+ """
+ Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
+
+ Note that when used as a hook,
+ it is unable to save additional data other than what's defined
+ by the given `checkpointer`.
+
+ It is executed every ``period`` iterations and after the last iteration.
+ """
+
+ def before_train(self):
+ self.max_iter = self.trainer.max_iter
+
+ def after_step(self):
+ # No way to use **kwargs
+ self.step(self.trainer.iter)
+
+
+class BestCheckpointer(HookBase):
+ """
+ Checkpoints best weights based off given metric.
+
+ This hook should be used in conjunction to and executed after the hook
+ that produces the metric, e.g. `EvalHook`.
+ """
+
+ def __init__(
+ self,
+ eval_period: int,
+ checkpointer: Checkpointer,
+ val_metric: str,
+ mode: str = "max",
+ file_prefix: str = "model_best",
+ ) -> None:
+ """
+ Args:
+ eval_period (int): the period `EvalHook` is set to run.
+ checkpointer: the checkpointer object used to save checkpoints.
+ val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50"
+ mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
+ maximized or minimized, e.g. for "bbox/AP50" it should be "max"
+ file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
+ """
+ self._logger = logging.getLogger(__name__)
+ self._period = eval_period
+ self._val_metric = val_metric
+ assert mode in [
+ "max",
+ "min",
+ ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.'
+ if mode == "max":
+ self._compare = operator.gt
+ else:
+ self._compare = operator.lt
+ self._checkpointer = checkpointer
+ self._file_prefix = file_prefix
+ self.best_metric = None
+ self.best_iter = None
+
+ def _update_best(self, val, iteration):
+ if math.isnan(val) or math.isinf(val):
+ return False
+ self.best_metric = val
+ self.best_iter = iteration
+ return True
+
+    def _best_checking(self):
+        """Save a checkpoint if the latest value of the tracked metric is the best so far."""
+        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
+        if metric_tuple is None:
+            self._logger.warning(
+                f"Given val metric {self._val_metric} does not seem to be computed/stored. "
+                "Will not be checkpointing based on it."
+            )
+            return
+        latest_metric, metric_iter = metric_tuple
+
+        if self.best_metric is None:
+            if self._update_best(latest_metric, metric_iter):
+                additional_state = {"iteration": metric_iter}
+                self._checkpointer.save(f"{self._file_prefix}", **additional_state)
+                self._logger.info(
+                    f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
+                )
+        elif self._compare(latest_metric, self.best_metric):
+            additional_state = {"iteration": metric_iter}
+            self._checkpointer.save(f"{self._file_prefix}", **additional_state)
+            self._logger.info(
+                f"Saved best model as latest eval score for {self._val_metric} is "
+                f"{latest_metric:0.5f}, better than last best score "
+                f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
+            )
+            self._update_best(latest_metric, metric_iter)
+        else:
+            self._logger.info(
+                f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, "
+                f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}."
+            )
+
+ def after_step(self):
+ # same conditions as `EvalHook`
+ next_iter = self.trainer.iter + 1
+ if (
+ self._period > 0
+ and next_iter % self._period == 0
+ and next_iter != self.trainer.max_iter
+ ):
+ self._best_checking()
+
+ def after_train(self):
+ # same conditions as `EvalHook`
+ if self.trainer.iter + 1 >= self.trainer.max_iter:
+ self._best_checking()
+
+
+class LRScheduler(HookBase):
+ """
+ A hook which executes a torch builtin LR scheduler and summarizes the LR.
+ It is executed after every iteration.
+ """
+
+ def __init__(self, optimizer=None, scheduler=None):
+ """
+ Args:
+ optimizer (torch.optim.Optimizer):
+ scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
+ if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
+ in the optimizer.
+
+ If any argument is not given, will try to obtain it from the trainer.
+ """
+ self._optimizer = optimizer
+ self._scheduler = scheduler
+
+ def before_train(self):
+ self._optimizer = self._optimizer or self.trainer.optimizer
+ if isinstance(self.scheduler, ParamScheduler):
+ self._scheduler = LRMultiplier(
+ self._optimizer,
+ self.scheduler,
+ self.trainer.max_iter,
+ last_iter=self.trainer.iter - 1,
+ )
+ self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)
+
+ @staticmethod
+ def get_best_param_group_id(optimizer):
+ # NOTE: some heuristics on what LR to summarize
+ # summarize the param group with most parameters
+ largest_group = max(len(g["params"]) for g in optimizer.param_groups)
+
+ if largest_group == 1:
+ # If all groups have one parameter,
+ # then find the most common initial LR, and use it for summary
+ lr_count = Counter([g["lr"] for g in optimizer.param_groups])
+ lr = lr_count.most_common()[0][0]
+ for i, g in enumerate(optimizer.param_groups):
+ if g["lr"] == lr:
+ return i
+ else:
+ for i, g in enumerate(optimizer.param_groups):
+ if len(g["params"]) == largest_group:
+ return i
+
+ def after_step(self):
+ lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
+ self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
+ self.scheduler.step()
+
+ @property
+ def scheduler(self):
+ return self._scheduler or self.trainer.scheduler
+
+ def state_dict(self):
+ if isinstance(self.scheduler, _LRScheduler):
+ return self.scheduler.state_dict()
+ return {}
+
+ def load_state_dict(self, state_dict):
+ if isinstance(self.scheduler, _LRScheduler):
+ logger = logging.getLogger(__name__)
+ logger.info("Loading scheduler from state_dict ...")
+ self.scheduler.load_state_dict(state_dict)
+
+
+class TorchProfiler(HookBase):
+ """
+ A hook which runs `torch.profiler.profile`.
+
+ Examples:
+ ::
+ hooks.TorchProfiler(
+ lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
+ )
+
+ The above example will run the profiler for iteration 10~20 and dump
+ results to ``OUTPUT_DIR``. We did not profile the first few iterations
+ because they are typically slower than the rest.
+ The result files can be loaded in the ``chrome://tracing`` page in chrome browser,
+ and the tensorboard visualizations can be visualized using
+ ``tensorboard --logdir OUTPUT_DIR/log``
+ """
+
+ def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True):
+ """
+ Args:
+ enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
+ and returns whether to enable the profiler.
+ It will be called once every step, and can be used to select which steps to profile.
+ output_dir (str): the output directory to dump tracing files.
+ activities (iterable): same as in `torch.profiler.profile`.
+ save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/
+ """
+ self._enable_predicate = enable_predicate
+ self._activities = activities
+ self._output_dir = output_dir
+ self._save_tensorboard = save_tensorboard
+
+ def before_step(self):
+ if self._enable_predicate(self.trainer):
+ if self._save_tensorboard:
+ on_trace_ready = torch.profiler.tensorboard_trace_handler(
+ os.path.join(
+ self._output_dir,
+ "log",
+ "profiler-tensorboard-iter{}".format(self.trainer.iter),
+ ),
+ f"worker{comm.get_rank()}",
+ )
+ else:
+ on_trace_ready = None
+ self._profiler = torch.profiler.profile(
+ activities=self._activities,
+ on_trace_ready=on_trace_ready,
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ with_flops=True,
+ )
+ self._profiler.__enter__()
+ else:
+ self._profiler = None
+
+ def after_step(self):
+ if self._profiler is None:
+ return
+ self._profiler.__exit__(None, None, None)
+ if not self._save_tensorboard:
+ PathManager.mkdirs(self._output_dir)
+ out_file = os.path.join(
+ self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
+ )
+ if "://" not in out_file:
+ self._profiler.export_chrome_trace(out_file)
+ else:
+ # Support non-posix filesystems
+ with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
+ tmp_file = os.path.join(d, "tmp.json")
+ self._profiler.export_chrome_trace(tmp_file)
+ with open(tmp_file) as f:
+ content = f.read()
+ with PathManager.open(out_file, "w") as f:
+ f.write(content)
+
+
+class AutogradProfiler(TorchProfiler):
+    """
+    A hook which runs `torch.autograd.profiler.profile`.
+
+    Examples:
+    ::
+        hooks.AutogradProfiler(
+            lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR
+        )
+
+    The above example will run the profiler for iteration 10~20 and dump results to ``OUTPUT_DIR``.
+    We did not profile the first few iterations because they are typically slower than the rest.
+    The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
+
+    Note:
+        When used together with NCCL on older version of GPUs, autograd profiler may cause
+        deadlock because it unnecessarily allocates memory on every device it sees. The memory
+        management calls, if interleaved with NCCL calls, lead to deadlock on GPUs that do not
+        support ``cudaLaunchCooperativeKernelMultiDevice``.
+    """
+
+    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
+        """
+        Args:
+            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
+                and returns whether to enable the profiler.
+                It will be called once every step, and can be used to select which steps to profile.
+            output_dir (str): the output directory to dump tracing files.
+            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
+        """
+        warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.")
+        self._enable_predicate = enable_predicate
+        self._use_cuda = use_cuda
+        self._output_dir = output_dir
+        # The inherited after_step() reads this flag; False selects the chrome-trace export.
+        self._save_tensorboard = False
+
+    def before_step(self):
+        if self._enable_predicate(self.trainer):
+            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
+            self._profiler.__enter__()
+        else:
+            self._profiler = None
+
+
+class EvalHook(HookBase):
+ """
+ Run an evaluation function periodically, and at the end of training.
+
+ It is executed every ``eval_period`` iterations and after the last iteration.
+ """
+
+ def __init__(self, eval_period, eval_function, eval_after_train=True):
+ """
+ Args:
+ eval_period (int): the period to run `eval_function`. Set to 0 to
+ not evaluate periodically (but still evaluate after the last iteration
+ if `eval_after_train` is True).
+ eval_function (callable): a function which takes no arguments, and
+ returns a nested dict of evaluation metrics.
+ eval_after_train (bool): whether to evaluate after the last iteration
+
+ Note:
+ This hook must be enabled in all or none workers.
+ If you would like only certain workers to perform evaluation,
+ give other workers a no-op function (`eval_function=lambda: None`).
+ """
+ self._period = eval_period
+ self._func = eval_function
+ self._eval_after_train = eval_after_train
+
+ def _do_eval(self):
+ results = self._func()
+
+ if results:
+ assert isinstance(
+ results, dict
+ ), "Eval function must return a dict. Got {} instead.".format(results)
+
+ flattened_results = flatten_results_dict(results)
+ for k, v in flattened_results.items():
+ try:
+ v = float(v)
+ except Exception as e:
+ raise ValueError(
+ "[EvalHook] eval_function should return a nested dict of float. "
+ "Got '{}: {}' instead.".format(k, v)
+ ) from e
+ self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ if self._period > 0 and next_iter % self._period == 0:
+ # do the last eval in after_train
+ if next_iter != self.trainer.max_iter:
+ self._do_eval()
+
+ def after_train(self):
+ # This condition is to prevent the eval from running after a failed training
+ if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter:
+ self._do_eval()
+ # func is likely a closure that holds reference to the trainer
+ # therefore we clean it to avoid circular reference in the end
+ del self._func
+
+
+class PreciseBN(HookBase):
+ """
+ The standard implementation of BatchNorm uses EMA in inference, which is
+ sometimes suboptimal.
+ This class computes the true average of statistics rather than the moving average,
+ and put true averages to every BN layer in the given model.
+
+ It is executed every ``period`` iterations and after the last iteration.
+ """
+
+    def __init__(self, period, model, data_loader, num_iter):
+        """
+        Args:
+            period (int): the period this hook is run, or 0 to not run during training.
+                The hook will always run in the end of training.
+            model (nn.Module): a module whose all BN layers in training mode will be
+                updated by precise BN.
+                Note that user is responsible for ensuring the BN layers to be
+                updated are in training mode when this hook is triggered.
+            data_loader (iterable): it will produce data to be run by `model(data)`.
+            num_iter (int): number of iterations used to compute the precise statistics.
+        """
+        self._logger = logging.getLogger(__name__)
+        # Set before the early return: after_step() reads it even when the hook is disabled.
+        self._period = period
+        if len(get_bn_modules(model)) == 0:
+            self._logger.info(
+                "PreciseBN is disabled because model does not contain BN layers in training mode."
+            )
+            self._disabled = True
+            return
+
+        self._model = model
+        self._data_loader = data_loader
+        self._num_iter = num_iter
+        self._disabled = False
+
+        self._data_iter = None
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ is_final = next_iter == self.trainer.max_iter
+ if is_final or (self._period > 0 and next_iter % self._period == 0):
+ self.update_stats()
+
+ def update_stats(self):
+ """
+ Update the model with precise statistics. Users can manually call this method.
+ """
+ if self._disabled:
+ return
+
+ if self._data_iter is None:
+ self._data_iter = iter(self._data_loader)
+
+ def data_loader():
+ for num_iter in itertools.count(1):
+ if num_iter % 100 == 0:
+ self._logger.info(
+ "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
+ )
+ # This way we can reuse the same iterator
+ yield next(self._data_iter)
+
+ with EventStorage(): # capture events in a new storage to discard them
+ self._logger.info(
+ "Running precise-BN for {} iterations... ".format(self._num_iter)
+ + "Note that this could produce different statistics every time."
+ )
+ update_bn_stats(self._model, data_loader(), self._num_iter)
+
+
+class TorchMemoryStats(HookBase):
+ """
+ Writes pytorch's cuda memory statistics periodically.
+ """
+
+ def __init__(self, period=20, max_runs=10):
+ """
+ Args:
+ period (int): Output stats each 'period' iterations
+ max_runs (int): Stop the logging after 'max_runs'
+ """
+
+ self._logger = logging.getLogger(__name__)
+ self._period = period
+ self._max_runs = max_runs
+ self._runs = 0
+
+    def after_step(self):
+        """Log CUDA memory stats periodically and on the last iteration, at most `max_runs` times."""
+        if self._runs >= self._max_runs:
+            return
+
+        if (self.trainer.iter + 1) % self._period == 0 or (
+            self.trainer.iter == self.trainer.max_iter - 1
+        ):
+            if torch.cuda.is_available():
+                max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0
+                reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0
+                max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+                allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0
+
+                self._logger.info(
+                    (
+                        " iter: {} "
+                        " max_reserved_mem: {:.0f}MB "
+                        " reserved_mem: {:.0f}MB "
+                        " max_allocated_mem: {:.0f}MB "
+                        " allocated_mem: {:.0f}MB "
+                    ).format(
+                        self.trainer.iter,
+                        max_reserved_mb,
+                        reserved_mb,
+                        max_allocated_mb,
+                        allocated_mb,
+                    )
+                )
+                self._runs += 1
+                if self._runs == self._max_runs:
+                    mem_summary = torch.cuda.memory_summary()
+                    self._logger.info("\n" + mem_summary)
+
+                torch.cuda.reset_peak_memory_stats()
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/engine/launch.py b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2d6bcdb5f1906d3eedb04b5aa939f8269f0344
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/launch.py
@@ -0,0 +1,123 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from datetime import timedelta
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from annotator.oneformer.detectron2.utils import comm
+
+__all__ = ["DEFAULT_TIMEOUT", "launch"]
+# Default for the ``timeout`` parameter of launch() and _distributed_worker().
+DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+def _find_free_port():
+    """Ask the OS for a currently unused TCP port and return its number (int)."""
+    import socket
+
+    # The context manager closes the socket even if bind()/getsockname() raises.
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+        # Binding to port 0 will cause the OS to find an available port for us
+        sock.bind(("", 0))
+        # NOTE: there is still a chance the port could be taken by other processes.
+        return sock.getsockname()[1]
+
+
+def launch(
+    main_func,
+    # Should be num_processes_per_machine, but kept for compatibility.
+    num_gpus_per_machine,
+    num_machines=1,
+    machine_rank=0,
+    dist_url=None,
+    args=(),
+    timeout=DEFAULT_TIMEOUT,
+):
+    """
+    Launch multi-process or distributed training.
+    This function must be called on all machines involved in the training.
+    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
+    When the total number of processes is 1, ``main_func(*args)`` runs directly in this process.
+    Args:
+        main_func: a function that will be called by `main_func(*args)`
+        num_gpus_per_machine (int): number of processes per machine. When
+            using GPUs, this should be the number of GPUs.
+        num_machines (int): the total number of machines
+        machine_rank (int): the rank of this machine
+        dist_url (str): url to connect to for distributed jobs, including protocol
+            e.g. "tcp://127.0.0.1:8686". Can be set to "auto" to automatically select
+            a free port on localhost. None (the default) is forwarded as-is to the workers.
+        timeout (timedelta): timeout of the distributed workers
+        args (tuple): arguments passed to main_func
+    """
+    world_size = num_machines * num_gpus_per_machine
+    if world_size > 1:
+        # https://github.com/pytorch/pytorch/pull/14391
+        # TODO prctl in spawned processes
+
+        if dist_url == "auto":
+            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
+            port = _find_free_port()
+            dist_url = f"tcp://127.0.0.1:{port}"
+        if num_machines > 1 and dist_url is not None and dist_url.startswith("file://"):
+            logger = logging.getLogger(__name__)
+            logger.warning(
+                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
+            )
+        # dist_url may still be None here; it was only inspected above, never required.
+        mp.start_processes(
+            _distributed_worker,
+            nprocs=num_gpus_per_machine,
+            args=(
+                main_func,
+                world_size,
+                num_gpus_per_machine,
+                machine_rank,
+                dist_url,
+                args,
+                timeout,
+            ),
+            daemon=False,
+        )
+    else:
+        main_func(*args)
+
+
+def _distributed_worker(
+    local_rank,
+    main_func,
+    world_size,
+    num_gpus_per_machine,
+    machine_rank,
+    dist_url,
+    args,
+    timeout=DEFAULT_TIMEOUT,
+):
+    """Entry point of each spawned process: join the process group, then run ``main_func(*args)``."""
+    has_gpu = torch.cuda.is_available()
+    if has_gpu:
+        assert num_gpus_per_machine <= torch.cuda.device_count()
+    global_rank = machine_rank * num_gpus_per_machine + local_rank
+    try:
+        dist.init_process_group(
+            backend="NCCL" if has_gpu else "GLOO",
+            init_method=dist_url,
+            world_size=world_size,
+            rank=global_rank,
+            timeout=timeout,
+        )
+    except Exception:
+        logging.getLogger(__name__).error("Process group URL: {}".format(dist_url))
+        raise  # bare raise re-raises the active exception unchanged
+
+    # Setup the local process group.
+    comm.create_local_process_group(num_gpus_per_machine)
+    if has_gpu:
+        torch.cuda.set_device(local_rank)
+
+    # synchronize is needed here to prevent a possible timeout after calling init_process_group
+    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+    comm.synchronize()
+
+    main_func(*args)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/engine/train_loop.py b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/train_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c24c5af94e8f9367a5d577a617ec426292d3f89
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/engine/train_loop.py
@@ -0,0 +1,469 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import logging
+import numpy as np
+import time
+import weakref
+from typing import List, Mapping, Optional
+import torch
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+import annotator.oneformer.detectron2.utils.comm as comm
+from annotator.oneformer.detectron2.utils.events import EventStorage, get_event_storage
+from annotator.oneformer.detectron2.utils.logger import _log_api_usage
+
+__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]
+
+
+class HookBase:
+ """
+ Base class for hooks that can be registered with :class:`TrainerBase`.
+
+ Each hook can implement the methods below; :meth:`after_backward` is not shown in the
+ snippet. The way the other four are called is demonstrated in the following snippet:
+ ::
+ hook.before_train()
+ for iter in range(start_iter, max_iter):
+ hook.before_step()
+ trainer.run_step()
+ hook.after_step()
+ iter += 1
+ hook.after_train()
+
+ Notes:
+ 1. In the hook method, users can access ``self.trainer`` to access more
+ properties about the context (e.g., model, current iteration, or config
+ if using :class:`DefaultTrainer`).
+
+ 2. A hook that does something in :meth:`before_step` can often be
+ implemented equivalently in :meth:`after_step`.
+ If the hook takes non-trivial time, it is strongly recommended to
+ implement the hook in :meth:`after_step` instead of :meth:`before_step`.
+ The convention is that :meth:`before_step` should only take negligible time.
+
+ Following this convention will allow hooks that do care about the difference
+ between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
+ function properly.
+
+ """
+
+ trainer: "TrainerBase" = None
+ """
+ A weak reference to the trainer object. Set by the trainer when the hook is registered.
+ """
+
+ def before_train(self):
+ """
+ Called before the first iteration.
+ """
+ pass
+
+ def after_train(self):
+ """
+ Called after the last iteration.
+ """
+ pass
+
+ def before_step(self):
+ """
+ Called before each iteration.
+ """
+ pass
+
+ def after_backward(self):
+ """
+ Called after the backward pass of each iteration.
+ """
+ pass
+
+ def after_step(self):
+ """
+ Called after each iteration.
+ """
+ pass
+
+ def state_dict(self):
+ """
+ Hooks are stateless by default, but can be made checkpointable by
+ implementing `state_dict` and `load_state_dict`. This default returns an empty dict.
+ """
+ return {}
+
+
+class TrainerBase:
+ """
+ Base class for iterative trainer with hooks.
+
+ The only assumption we made here is: the training runs in a loop.
+ A subclass can implement what the loop is.
+ We made no assumptions about the existence of dataloader, optimizer, model, etc.
+
+ Attributes:
+ iter(int): the current iteration.
+
+ start_iter(int): The iteration to start with.
+ By convention the minimum possible value is 0.
+
+ max_iter(int): The iteration to end training.
+
+ storage(EventStorage): An EventStorage that's opened during the course of training.
+ """
+
+ def __init__(self) -> None:
+ self._hooks: List[HookBase] = []
+ self.iter: int = 0
+ self.start_iter: int = 0
+ self.max_iter: int  # annotation only; the value is assigned in train()
+ self.storage: EventStorage  # annotation only; bound by the `with` statement in train()
+ _log_api_usage("trainer." + self.__class__.__name__)
+
+ def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
+ """
+ Register hooks to the trainer. The hooks are executed in the order
+ they are registered.
+
+ Args:
+ hooks (list[Optional[HookBase]]): list of hooks; None entries are dropped
+ """
+ hooks = [h for h in hooks if h is not None]
+ for h in hooks:
+ assert isinstance(h, HookBase)
+ # To avoid circular reference, hooks and trainer cannot own each other.
+ # This normally does not matter, but will cause memory leak if the
+ # involved objects contain __del__:
+ # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
+ h.trainer = weakref.proxy(self)
+ self._hooks.extend(hooks)
+
+ def train(self, start_iter: int, max_iter: int):
+ """Run iterations ``start_iter`` .. ``max_iter - 1``; ``after_train`` is always called (``finally``).
+ Args:
+ start_iter, max_iter (int): See docs above
+ """
+ logger = logging.getLogger(__name__)
+ logger.info("Starting training from iteration {}".format(start_iter))
+
+ self.iter = self.start_iter = start_iter
+ self.max_iter = max_iter
+
+ with EventStorage(start_iter) as self.storage:
+ try:
+ self.before_train()
+ for self.iter in range(start_iter, max_iter):
+ self.before_step()
+ self.run_step()
+ self.after_step()
+ # self.iter == max_iter can be used by `after_train` to
+ # tell whether the training successfully finished or failed
+ # due to exceptions.
+ self.iter += 1
+ except Exception:
+ logger.exception("Exception during training:")
+ raise
+ finally:
+ self.after_train()
+
+ def before_train(self):
+ for h in self._hooks:
+ h.before_train()
+
+ def after_train(self):
+ self.storage.iter = self.iter
+ for h in self._hooks:
+ h.after_train()
+
+ def before_step(self):
+ # Maintain the invariant that storage.iter == trainer.iter
+ # for the entire execution of each step
+ self.storage.iter = self.iter
+
+ for h in self._hooks:
+ h.before_step()
+
+ def after_backward(self):
+ for h in self._hooks:
+ h.after_backward()
+
+ def after_step(self):
+ for h in self._hooks:
+ h.after_step()
+
+ def run_step(self):
+ raise NotImplementedError
+
+ def state_dict(self):
+ ret = {"iteration": self.iter}
+ hooks_state = {}
+ for h in self._hooks:
+ sd = h.state_dict()
+ if sd:  # hooks returning an empty state are not saved
+ name = type(h).__qualname__  # hook state is keyed by the hook's class qualname
+ if name in hooks_state:
+ # TODO handle repetitive stateful hooks
+ continue
+ hooks_state[name] = sd
+ if hooks_state:
+ ret["hooks"] = hooks_state
+ return ret
+
+ def load_state_dict(self, state_dict):
+ logger = logging.getLogger(__name__)
+ self.iter = state_dict["iteration"]
+ for key, value in state_dict.get("hooks", {}).items():
+ for h in self._hooks:
+ try:
+ name = type(h).__qualname__
+ except AttributeError:
+ continue
+ if name == key:
+ h.load_state_dict(value)
+ break  # only the first registered hook with this name receives the state
+ else:
+ logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
+
+
+class SimpleTrainer(TrainerBase):
+ """
+ A simple trainer for the most common type of task:
+ single-cost single-optimizer single-data-source iterative optimization,
+ optionally using data-parallelism.
+ It assumes that every step, you:
+
+ 1. Compute the loss with a data from the data_loader.
+ 2. Compute the gradients with the above loss.
+ 3. Update the model with the optimizer.
+
+ All other tasks during training (checkpointing, logging, evaluation, LR schedule)
+ are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.
+
+ If you want to do anything fancier than this,
+ either subclass TrainerBase and implement your own `run_step`,
+ or write your own training loop.
+ """
+
+ def __init__(self, model, data_loader, optimizer, gather_metric_period=1):
+ """
+ Args:
+ model: a torch Module. Takes a data from data_loader and returns a
+ dict of losses (a single loss Tensor is also handled by ``run_step``).
+ data_loader: an iterable. Contains data to be used to call model.
+ optimizer: a torch optimizer.
+ gather_metric_period: an int. Every gather_metric_period iterations
+ the metrics are gathered from all the ranks to rank 0 and logged.
+ """
+ super().__init__()
+
+ """
+ We set the model to training mode in the trainer.
+ However it's valid to train a model that's in eval mode.
+ If you want your model (or a submodule of it) to behave
+ like evaluation during training, you can overwrite its train() method.
+ """
+ model.train()
+
+ self.model = model
+ self.data_loader = data_loader
+ # to access the data loader iterator, call `self._data_loader_iter`
+ self._data_loader_iter_obj = None  # created lazily by the _data_loader_iter property
+ self.optimizer = optimizer
+ self.gather_metric_period = gather_metric_period
+
+ def run_step(self):
+ """
+ Implement the standard training logic described above.
+ """
+ assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
+ start = time.perf_counter()
+ """
+ If you want to do something with the data, you can wrap the dataloader.
+ """
+ data = next(self._data_loader_iter)
+ data_time = time.perf_counter() - start
+
+ """
+ If you want to do something with the losses, you can wrap the model.
+ """
+ loss_dict = self.model(data)
+ if isinstance(loss_dict, torch.Tensor):
+ losses = loss_dict
+ loss_dict = {"total_loss": loss_dict}
+ else:
+ losses = sum(loss_dict.values())
+
+ """
+ If you need to accumulate gradients or do something similar, you can
+ wrap the optimizer with your custom `zero_grad()` method.
+ """
+ self.optimizer.zero_grad()
+ losses.backward()
+
+ self.after_backward()
+
+ self._write_metrics(loss_dict, data_time)
+
+ """
+ If you need gradient clipping/scaling or other processing, you can
+ wrap the optimizer with your custom `step()` method. But it is
+ suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
+ """
+ self.optimizer.step()
+
+ @property
+ def _data_loader_iter(self):
+ # only create the data loader iterator when it is used
+ if self._data_loader_iter_obj is None:
+ self._data_loader_iter_obj = iter(self.data_loader)
+ return self._data_loader_iter_obj
+
+ def reset_data_loader(self, data_loader_builder):
+ """
+ Delete and replace the current data loader with a new one, which will be created
+ by calling `data_loader_builder` (without argument).
+ """
+ del self.data_loader
+ data_loader = data_loader_builder()
+ self.data_loader = data_loader
+ self._data_loader_iter_obj = None  # drop the old iterator so the next access rebuilds it
+
+ def _write_metrics(
+ self,
+ loss_dict: Mapping[str, torch.Tensor],
+ data_time: float,
+ prefix: str = "",
+ ) -> None:
+ if (self.iter + 1) % self.gather_metric_period == 0:  # only every gather_metric_period-th step
+ SimpleTrainer.write_metrics(loss_dict, data_time, prefix)  # NOTE(review): bypasses subclass overrides
+
+ @staticmethod
+ def write_metrics(
+ loss_dict: Mapping[str, torch.Tensor],
+ data_time: float,
+ prefix: str = "",
+ ) -> None:
+ """
+ Args:
+ loss_dict (dict): dict of scalar losses
+ data_time (float): time taken by the dataloader iteration
+ prefix (str): prefix for logging keys. Raises FloatingPointError on a non-finite total.
+ """
+ metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
+ metrics_dict["data_time"] = data_time
+
+ # Gather metrics among all workers for logging
+ # This assumes we do DDP-style training, which is currently the only
+ # supported method in detectron2.
+ all_metrics_dict = comm.gather(metrics_dict)
+
+ if comm.is_main_process():
+ storage = get_event_storage()
+
+ # data_time among workers can have high variance. The actual latency
+ # caused by data_time is the maximum among workers.
+ data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
+ storage.put_scalar("data_time", data_time)
+
+ # average the rest metrics
+ metrics_dict = {
+ k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
+ }
+ total_losses_reduced = sum(metrics_dict.values())
+ if not np.isfinite(total_losses_reduced):
+ raise FloatingPointError(
+ f"Loss became infinite or NaN at iteration={storage.iter}!\n"
+ f"loss_dict = {metrics_dict}"
+ )
+
+ storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
+ if len(metrics_dict) > 1:  # a single metric equals the total that was just logged
+ storage.put_scalars(**metrics_dict)  # NOTE(review): `prefix` is not applied to these keys
+
+ def state_dict(self):
+ ret = super().state_dict()
+ ret["optimizer"] = self.optimizer.state_dict()
+ return ret
+
+ def load_state_dict(self, state_dict):
+ super().load_state_dict(state_dict)
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+
+
+class AMPTrainer(SimpleTrainer):
+ """
+ Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
+ in the training loop.
+ """
+
+ def __init__(
+ self,
+ model,
+ data_loader,
+ optimizer,
+ gather_metric_period=1,
+ grad_scaler=None,
+ precision: torch.dtype = torch.float16,
+ log_grad_scaler: bool = False,
+ ):
+ """
+ Args:
+ model, data_loader, optimizer, gather_metric_period: same as in :class:`SimpleTrainer`.
+ grad_scaler: torch GradScaler to automatically scale gradients; a default one is created if None.
+ precision: torch.dtype passed to autocast. log_grad_scaler: if True, log the scale every step.
+ """
+ unsupported = "AMPTrainer does not support single-process multi-device training!"
+ if isinstance(model, DistributedDataParallel):
+ assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+ assert not isinstance(model, DataParallel), unsupported
+
+ super().__init__(model, data_loader, optimizer, gather_metric_period)
+
+ if grad_scaler is None:
+ from torch.cuda.amp import GradScaler
+
+ grad_scaler = GradScaler()
+ self.grad_scaler = grad_scaler
+ self.precision = precision
+ self.log_grad_scaler = log_grad_scaler
+
+ def run_step(self):
+ """
+ Implement the AMP training logic.
+ """
+ assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
+ assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+ from torch.cuda.amp import autocast
+
+ start = time.perf_counter()
+ data = next(self._data_loader_iter)
+ data_time = time.perf_counter() - start
+
+ with autocast(dtype=self.precision):
+ loss_dict = self.model(data)
+ if isinstance(loss_dict, torch.Tensor):
+ losses = loss_dict
+ loss_dict = {"total_loss": loss_dict}
+ else:
+ losses = sum(loss_dict.values())
+
+ self.optimizer.zero_grad()
+ self.grad_scaler.scale(losses).backward()  # backward runs on the scaled loss
+
+ if self.log_grad_scaler:
+ storage = get_event_storage()
+ storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale())
+
+ self.after_backward()
+
+ self._write_metrics(loss_dict, data_time)
+
+ self.grad_scaler.step(self.optimizer)  # the optimizer step goes through the scaler
+ self.grad_scaler.update()
+
+ def state_dict(self):
+ ret = super().state_dict()
+ ret["grad_scaler"] = self.grad_scaler.state_dict()
+ return ret
+
+ def load_state_dict(self, state_dict):
+ super().load_state_dict(state_dict)
+ self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d96609e8f2261a6800fe85fcf3e1eaeaa44455c6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .cityscapes_evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator
+from .coco_evaluation import COCOEvaluator
+from .rotated_coco_evaluation import RotatedCOCOEvaluator
+from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
+from .lvis_evaluation import LVISEvaluator
+from .panoptic_evaluation import COCOPanopticEvaluator
+from .pascal_voc_evaluation import PascalVOCDetectionEvaluator
+from .sem_seg_evaluation import SemSegEvaluator
+from .testing import print_csv_format, verify_results
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/cityscapes_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/cityscapes_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5be637dc87b5ca8645563a4a921144f6c5fd877
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/cityscapes_evaluation.py
@@ -0,0 +1,197 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import glob
+import logging
+import numpy as np
+import os
+import tempfile
+from collections import OrderedDict
+import torch
+from PIL import Image
+
+from annotator.oneformer.detectron2.data import MetadataCatalog
+from annotator.oneformer.detectron2.utils import comm
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .evaluator import DatasetEvaluator
+
+
+class CityscapesEvaluator(DatasetEvaluator):
+ """
+ Base class for evaluation using cityscapes API.
+ """
+
+ def __init__(self, dataset_name):
+ """
+ Args:
+ dataset_name (str): the name of the dataset.
+ It must have the following metadata associated with it:
+ "thing_classes", "gt_dir".
+ """
+ self._metadata = MetadataCatalog.get(dataset_name)
+ self._cpu_device = torch.device("cpu")
+ self._logger = logging.getLogger(__name__)
+
+ def reset(self):
+ self._working_dir = tempfile.TemporaryDirectory(prefix="cityscapes_eval_")
+ self._temp_dir = self._working_dir.name
+ # All workers will write to the same results directory
+ # TODO this does not work in distributed training
+ assert (
+ comm.get_local_size() == comm.get_world_size()
+ ), "CityscapesEvaluator currently do not work with multiple machines."
+ self._temp_dir = comm.all_gather(self._temp_dir)[0]  # presumably rank 0's dir; confirm all_gather order
+ if self._temp_dir != self._working_dir.name:
+ self._working_dir.cleanup()  # this process's own dir is not the shared one, so remove it
+ self._logger.info(
+ "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir)
+ )
+
+
+class CityscapesInstanceEvaluator(CityscapesEvaluator):
+ """
+ Evaluate instance segmentation results on cityscapes dataset using cityscapes API.
+
+ Note:
+ * It does not work in multi-machine distributed training.
+ * It contains a synchronization, therefore has to be used on all ranks.
+ * Only the main process runs evaluation.
+ """
+
+ def process(self, inputs, outputs):
+ from cityscapesscripts.helpers.labels import name2label
+
+ for input, output in zip(inputs, outputs):
+ file_name = input["file_name"]
+ basename = os.path.splitext(os.path.basename(file_name))[0]
+ pred_txt = os.path.join(self._temp_dir, basename + "_pred.txt")
+
+ if "instances" in output:
+ output = output["instances"].to(self._cpu_device)
+ num_instances = len(output)
+ with open(pred_txt, "w") as fout:
+ for i in range(num_instances):
+ pred_class = output.pred_classes[i]
+ classes = self._metadata.thing_classes[pred_class]
+ class_id = name2label[classes].id
+ score = output.scores[i]
+ mask = output.pred_masks[i].numpy().astype("uint8")
+ png_filename = os.path.join(
+ self._temp_dir, basename + "_{}_{}.png".format(i, classes)
+ )
+
+ Image.fromarray(mask * 255).save(png_filename)
+ fout.write(  # one line per instance: mask file name, class id, score
+ "{} {} {}\n".format(os.path.basename(png_filename), class_id, score)
+ )
+ else:
+ # Cityscapes requires a prediction file for every ground truth image.
+ with open(pred_txt, "w") as fout:
+ pass
+
+ def evaluate(self):
+ """
+ Returns:
+ dict: has a key "segm", whose value is a dict of "AP" and "AP50" (both scaled by 100).
+ """
+ comm.synchronize()
+ if comm.get_rank() > 0:
+ return  # only rank 0 evaluates; every other rank returns None
+ import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval
+
+ self._logger.info("Evaluating results under {} ...".format(self._temp_dir))
+
+ # set some global states in cityscapes evaluation API, before evaluating
+ cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
+ cityscapes_eval.args.predictionWalk = None
+ cityscapes_eval.args.JSONOutput = False
+ cityscapes_eval.args.colorized = False
+ cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")
+ # NOTE(review): the assert below reports args.groundTruthSearch, not the glob pattern actually used.
+ # These lines are adopted from
+ # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
+ gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
+ groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_instanceIds.png"))
+ assert len(
+ groundTruthImgList
+ ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
+ cityscapes_eval.args.groundTruthSearch
+ )
+ predictionImgList = []
+ for gt in groundTruthImgList:
+ predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args))
+ results = cityscapes_eval.evaluateImgLists(
+ predictionImgList, groundTruthImgList, cityscapes_eval.args
+ )["averages"]
+
+ ret = OrderedDict()
+ ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100}
+ self._working_dir.cleanup()
+ return ret
+
+
+class CityscapesSemSegEvaluator(CityscapesEvaluator):
+ """
+ Evaluate semantic segmentation results on cityscapes dataset using cityscapes API.
+
+ Note:
+ * It does not work in multi-machine distributed training.
+ * It contains a synchronization, therefore has to be used on all ranks.
+ * Only the main process runs evaluation.
+ """
+
+ def process(self, inputs, outputs):
+ from cityscapesscripts.helpers.labels import trainId2label
+
+ for input, output in zip(inputs, outputs):
+ file_name = input["file_name"]
+ basename = os.path.splitext(os.path.basename(file_name))[0]
+ pred_filename = os.path.join(self._temp_dir, basename + "_pred.png")
+
+ output = output["sem_seg"].argmax(dim=0).to(self._cpu_device).numpy()
+ pred = 255 * np.ones(output.shape, dtype=np.uint8)  # pixels not remapped below stay 255
+ for train_id, label in trainId2label.items():
+ if label.ignoreInEval:
+ continue
+ pred[output == train_id] = label.id  # translate train ids to label ids
+ Image.fromarray(pred).save(pred_filename)
+
+ def evaluate(self):
+ comm.synchronize()
+ if comm.get_rank() > 0:
+ return  # only rank 0 evaluates; every other rank returns None
+ # The eval script reads CITYSCAPES_DATASET into global variables at load time, hence the late import.
+ # NOTE(review): this method itself does not set that env var; confirm it is set elsewhere if needed.
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_eval
+
+ self._logger.info("Evaluating results under {} ...".format(self._temp_dir))
+
+ # set some global states in cityscapes evaluation API, before evaluating
+ cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
+ cityscapes_eval.args.predictionWalk = None
+ cityscapes_eval.args.JSONOutput = False
+ cityscapes_eval.args.colorized = False
+ # NOTE(review): the assert below reports args.groundTruthSearch, not the glob pattern actually used.
+ # These lines are adopted from
+ # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py # noqa
+ gt_dir = PathManager.get_local_path(self._metadata.gt_dir)
+ groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png"))
+ assert len(
+ groundTruthImgList
+ ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
+ cityscapes_eval.args.groundTruthSearch
+ )
+ predictionImgList = []
+ for gt in groundTruthImgList:
+ predictionImgList.append(cityscapes_eval.getPrediction(cityscapes_eval.args, gt))
+ results = cityscapes_eval.evaluateImgLists(
+ predictionImgList, groundTruthImgList, cityscapes_eval.args
+ )
+ ret = OrderedDict()
+ ret["sem_seg"] = {
+ "IoU": 100.0 * results["averageScoreClasses"],
+ "iIoU": 100.0 * results["averageScoreInstClasses"],
+ "IoU_sup": 100.0 * results["averageScoreCategories"],
+ "iIoU_sup": 100.0 * results["averageScoreInstCategories"],
+ }
+ self._working_dir.cleanup()
+ return ret
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/coco_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/coco_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdc41798537d3b2e6fc7096c9f4bebd724f1e395
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/coco_evaluation.py
@@ -0,0 +1,722 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import copy
+import io
+import itertools
+import json
+import logging
+import numpy as np
+import os
+import pickle
+from collections import OrderedDict
+import annotator.oneformer.pycocotools.mask as mask_util
+import torch
+from annotator.oneformer.pycocotools.coco import COCO
+from annotator.oneformer.pycocotools.cocoeval import COCOeval
+from tabulate import tabulate
+
+import annotator.oneformer.detectron2.utils.comm as comm
+from annotator.oneformer.detectron2.config import CfgNode
+from annotator.oneformer.detectron2.data import MetadataCatalog
+from annotator.oneformer.detectron2.data.datasets.coco import convert_to_coco_json
+from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+from annotator.oneformer.detectron2.utils.logger import create_small_table
+
+from .evaluator import DatasetEvaluator
+
+try:
+ from annotator.oneformer.detectron2.evaluation.fast_eval_api import COCOeval_opt
+except ImportError:
+ COCOeval_opt = COCOeval
+
+
class COCOEvaluator(DatasetEvaluator):
    """
    Evaluate AR for object proposals, AP for instance detection/segmentation, AP
    for keypoint detection outputs using COCO's metrics.
    See http://cocodataset.org/#detection-eval and
    http://cocodataset.org/#keypoints-eval to understand its metrics.
    The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
    the metric cannot be computed (e.g. due to no predictions made).

    In addition to COCO, this evaluator is able to support any bounding box detection,
    instance segmentation, or keypoint detection dataset.
    """

    def __init__(
        self,
        dataset_name,
        tasks=None,
        distributed=True,
        output_dir=None,
        *,
        max_dets_per_image=None,
        use_fast_impl=True,
        kpt_oks_sigmas=(),
        allow_cached_coco=True,
    ):
        """
        Args:
            dataset_name (str): name of the dataset to be evaluated.
                It must have either the following corresponding metadata:

                    "json_file": the path to the COCO format annotation

                Or it must be in detectron2's standard dataset format
                so it can be converted to COCO format automatically.
            tasks (tuple[str]): tasks that can be evaluated under the given
                configuration. A task is one of "bbox", "segm", "keypoints".
                By default, will infer this automatically from predictions.
                Passing a CfgNode here is deprecated: only its
                TEST.KEYPOINT_OKS_SIGMAS is read and tasks are then inferred.
            distributed (bool): if True, will collect results from all ranks and run
                evaluation in the main process.
                Otherwise, will only evaluate the results in the current process.
            output_dir (str): optional, an output directory to dump all
                results predicted on the dataset. The dump contains two files:

                1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
                   contains all the results in the format they are produced by the model.
                2. "coco_instances_results.json" a json file in COCO's result format.
            max_dets_per_image (int): limit on the maximum number of detections per image.
                By default in COCO, this limit is to 100, but this can be customized
                to be greater, as is needed in evaluation metrics AP fixed and AP pool
                (see https://arxiv.org/pdf/2102.01066.pdf)
                This doesn't affect keypoint evaluation.
            use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
                Although the results should be very close to the official implementation in COCO
                API, it is still recommended to compute results with the official API for use in
                papers. The faster implementation also uses more RAM.
            kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
                See http://cocodataset.org/#keypoints-eval
                When empty, it will use the defaults in COCO.
                Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
            allow_cached_coco (bool): Whether to use cached coco json from previous validation
                runs. You should set this to False if you need to use different validation data.
                Defaults to True.

        Raises:
            ValueError: if the dataset has no "json_file" metadata and no
                `output_dir` was given to store the converted annotations.
        """
        self._logger = logging.getLogger(__name__)
        self._distributed = distributed
        self._output_dir = output_dir
        # Start with an empty buffer so process()/evaluate() do not fail with an
        # AttributeError when the caller has not called reset() yet.
        self._predictions = []

        if use_fast_impl and (COCOeval_opt is COCOeval):
            self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.")
            use_fast_impl = False
        self._use_fast_impl = use_fast_impl

        # COCOeval requires the limit on the number of detections per image (maxDets) to be a list
        # with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the
        # 3rd element (100) is used as the limit on the number of detections per image when
        # evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval,
        # we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults.
        if max_dets_per_image is None:
            max_dets_per_image = [1, 10, 100]
        else:
            max_dets_per_image = [1, 10, max_dets_per_image]
        self._max_dets_per_image = max_dets_per_image

        if tasks is not None and isinstance(tasks, CfgNode):
            kpt_oks_sigmas = (
                tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
            )
            # Logger.warn is a deprecated alias that no longer exists in Python 3.13.
            self._logger.warning(
                "COCO Evaluator instantiated using config, this is deprecated behavior."
                " Please pass in explicit arguments instead."
            )
            self._tasks = None  # Inferring it from predictions should be better
        else:
            self._tasks = tasks

        self._cpu_device = torch.device("cpu")

        self._metadata = MetadataCatalog.get(dataset_name)
        if not hasattr(self._metadata, "json_file"):
            if output_dir is None:
                raise ValueError(
                    "output_dir must be provided to COCOEvaluator "
                    "for datasets not in COCO format."
                )
            self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")

            cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
            self._metadata.json_file = cache_path
            convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco)

        json_file = PathManager.get_local_path(self._metadata.json_file)
        # The COCO constructor prints to stdout; swallow that output.
        with contextlib.redirect_stdout(io.StringIO()):
            self._coco_api = COCO(json_file)

        # Test set json files do not contain annotations (evaluation must be
        # performed using the COCO evaluation server).
        self._do_evaluation = "annotations" in self._coco_api.dataset
        if self._do_evaluation:
            self._kpt_oks_sigmas = kpt_oks_sigmas

    def reset(self):
        """Drop all predictions accumulated so far."""
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"]}

            if "instances" in output:
                instances = output["instances"].to(self._cpu_device)
                prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
            if "proposals" in output:
                prediction["proposals"] = output["proposals"].to(self._cpu_device)
            # Only keep the entry when it carries something besides "image_id".
            if len(prediction) > 1:
                self._predictions.append(prediction)

    def evaluate(self, img_ids=None):
        """
        Args:
            img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset

        Returns:
            dict: a copy of the collected results keyed by task name. Empty when
            there are no predictions or, in distributed mode, on non-main processes.
        """
        if self._distributed:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        if len(predictions) == 0:
            self._logger.warning("[COCOEvaluator] Did not receive valid predictions.")
            return {}

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        self._results = OrderedDict()
        # The first prediction decides which kinds of evaluation are run.
        if "proposals" in predictions[0]:
            self._eval_box_proposals(predictions)
        if "instances" in predictions[0]:
            self._eval_predictions(predictions, img_ids=img_ids)
        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)

    def _tasks_from_predictions(self, predictions):
        """
        Get COCO API "tasks" (i.e. iou_type) from COCO-format predictions.

        "bbox" is always included; "segm" and "keypoints" are added when any
        prediction carries the corresponding key. Returns a sorted list.
        """
        tasks = {"bbox"}
        for pred in predictions:
            if "segmentation" in pred:
                tasks.add("segm")
            if "keypoints" in pred:
                tasks.add("keypoints")
        return sorted(tasks)

    def _eval_predictions(self, predictions, img_ids=None):
        """
        Evaluate predictions. Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        tasks = self._tasks or self._tasks_from_predictions(coco_results)

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
            all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
            num_classes = len(all_contiguous_ids)
            assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1

            reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
            for result in coco_results:
                category_id = result["category_id"]
                assert category_id < num_classes, (
                    f"A prediction has class={category_id}, "
                    f"but the dataset only has {num_classes} classes and "
                    f"predicted class id should be in [0, {num_classes - 1}]."
                )
                result["category_id"] = reverse_id_mapping[category_id]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info(
            "Evaluating predictions with {} COCO API...".format(
                "unofficial" if self._use_fast_impl else "official"
            )
        )
        for task in sorted(tasks):
            assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
            coco_eval = (
                _evaluate_predictions_on_coco(
                    self._coco_api,
                    coco_results,
                    task,
                    kpt_oks_sigmas=self._kpt_oks_sigmas,
                    cocoeval_fn=COCOeval_opt if self._use_fast_impl else COCOeval,
                    img_ids=img_ids,
                    max_dets_per_image=self._max_dets_per_image,
                )
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res

    def _eval_box_proposals(self, predictions):
        """
        Evaluate the box proposals in predictions.
        Fill self._results with the metrics for "box_proposals" task.
        """
        if self._output_dir:
            # Saving generated box proposals to file.
            # Predicted box_proposals are in XYXY_ABS mode.
            bbox_mode = BoxMode.XYXY_ABS.value
            ids, boxes, objectness_logits = [], [], []
            for prediction in predictions:
                ids.append(prediction["image_id"])
                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())

            proposal_data = {
                "boxes": boxes,
                "objectness_logits": objectness_logits,
                "ids": ids,
                "bbox_mode": bbox_mode,
            }
            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
                pickle.dump(proposal_data, f)

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating bbox proposals ...")
        res = {}
        # area name -> suffix used in the metric key ("AR", "ARs", "ARm", "ARl")
        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
        for limit in [100, 1000]:
            for area, suffix in areas.items():
                stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit)
                key = "AR{}@{:d}".format(suffix, limit)
                res[key] = float(stats["ar"].item() * 100)
        self._logger.info("Proposal metrics: \n" + create_small_table(res))
        self._results["box_proposals"] = res

    def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
        """
        Derive the desired score numbers from summarized COCOeval.

        Args:
            coco_eval (None or COCOEval): None represents no predictions from model.
            iou_type (str): one of "bbox", "segm", "keypoints".
            class_names (None or list[str]): if provided, will use it to predict
                per-category AP.

        Returns:
            a dict of {metric name: score}
        """

        metrics = {
            "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
            "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
            "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
        }[iou_type]

        if coco_eval is None:
            # Logger.warn is a deprecated alias that no longer exists in Python 3.13.
            self._logger.warning("No predictions from the model!")
            return {metric: float("nan") for metric in metrics}

        # the standard metrics; negative stats are reported as NaN
        results = {
            metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
            for idx, metric in enumerate(metrics)
        }
        self._logger.info(
            "Evaluation results for {}: \n".format(iou_type) + create_small_table(results)
        )
        if not np.isfinite(sum(results.values())):
            self._logger.info("Some metrics cannot be computed and is shown as NaN.")

        if class_names is None or len(class_names) <= 1:
            return results
        # Compute per-category AP
        # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
        precisions = coco_eval.eval["precision"]
        # precision has dims (iou, recall, cls, area range, max dets)
        assert len(class_names) == precisions.shape[2]

        results_per_category = []
        for idx, name in enumerate(class_names):
            # area range index 0: all area ranges
            # max dets index -1: typically 100 per image
            precision = precisions[:, :, idx, 0, -1]
            precision = precision[precision > -1]
            ap = np.mean(precision) if precision.size else float("nan")
            results_per_category.append(("{}".format(name), float(ap * 100)))

        # tabulate it: up to three (category, AP) pairs per row
        N_COLS = min(6, len(results_per_category) * 2)
        results_flatten = list(itertools.chain(*results_per_category))
        results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
        table = tabulate(
            results_2d,
            tablefmt="pipe",
            floatfmt=".3f",
            headers=["category", "AP"] * (N_COLS // 2),
            numalign="left",
        )
        self._logger.info("Per-category {} AP: \n".format(iou_type) + table)

        results.update({"AP-" + name: ap for name, ap in results_per_category})
        return results
+
+
def instances_to_coco_json(instances, img_id):
    """
    Dump an "Instances" object to a COCO-format json that's used for evaluation.

    Boxes are converted from XYXY_ABS to XYWH_ABS. Masks, when present, are
    RLE-encoded. Keypoints, when present, are shifted by -0.5 in x and y.
    The input `instances` is not modified.

    Args:
        instances (Instances): predictions for one image; expected to expose
            `pred_boxes`, `scores`, `pred_classes` and optionally
            `pred_masks` / `pred_keypoints`.
        img_id (int): the image id

    Returns:
        list[dict]: list of json annotations in COCO format (empty when there
        are no instances).
    """
    num_instance = len(instances)
    if num_instance == 0:
        return []

    boxes = instances.pred_boxes.tensor.numpy()
    boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
    boxes = boxes.tolist()
    scores = instances.scores.tolist()
    classes = instances.pred_classes.tolist()

    has_mask = instances.has("pred_masks")
    if has_mask:
        # use RLE to encode the masks, because they are too large and takes memory
        # since this evaluator stores outputs of the entire dataset
        rles = [
            mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
            for mask in instances.pred_masks
        ]
        for rle in rles:
            # "counts" is an array encoded by mask_util as a byte-stream. Python3's
            # json writer which always produces strings cannot serialize a bytestream
            # unless you decode it. Thankfully, utf-8 works out (which is also what
            # the annotator.oneformer.pycocotools/_mask.pyx does).
            rle["counts"] = rle["counts"].decode("utf-8")

    has_keypoints = instances.has("pred_keypoints")
    if has_keypoints:
        # In COCO annotations, keypoints coordinates are pixel indices.
        # However our predictions are floating point coordinates.
        # Therefore we subtract 0.5 to be consistent with the annotation format.
        # This is the inverse of data loading logic in `datasets/coco.py`.
        #
        # Work on a copy: shifting `instances.pred_keypoints` in place would
        # corrupt the caller's tensor, and a second evaluator processing the
        # same output would then shift the coordinates twice.
        keypoints = instances.pred_keypoints.clone()
        keypoints[:, :, :2] -= 0.5

    results = []
    for k in range(num_instance):
        result = {
            "image_id": img_id,
            "category_id": classes[k],
            "bbox": boxes[k],
            "score": scores[k],
        }
        if has_mask:
            result["segmentation"] = rles[k]
        if has_keypoints:
            result["keypoints"] = keypoints[k].flatten().tolist()
        results.append(result)
    return results
+
+
+# inspired from Detectron:
+# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
+def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None):
+ """
+ Evaluate detection proposal recall metrics. This function is a much
+ faster alternative to the official COCO API recall evaluation code. However,
+ it produces slightly different results.
+ """
+ # Record max overlap value for each gt box
+ # Return vector of overlap values
+ areas = {
+ "all": 0,
+ "small": 1,
+ "medium": 2,
+ "large": 3,
+ "96-128": 4,
+ "128-256": 5,
+ "256-512": 6,
+ "512-inf": 7,
+ }
+ area_ranges = [
+ [0**2, 1e5**2], # all
+ [0**2, 32**2], # small
+ [32**2, 96**2], # medium
+ [96**2, 1e5**2], # large
+ [96**2, 128**2], # 96-128
+ [128**2, 256**2], # 128-256
+ [256**2, 512**2], # 256-512
+ [512**2, 1e5**2],
+ ] # 512-inf
+ assert area in areas, "Unknown area range: {}".format(area)
+ area_range = area_ranges[areas[area]]
+ gt_overlaps = []
+ num_pos = 0
+
+ for prediction_dict in dataset_predictions:
+ predictions = prediction_dict["proposals"]
+
+ # sort predictions in descending order
+ # TODO maybe remove this and make it explicit in the documentation
+ inds = predictions.objectness_logits.sort(descending=True)[1]
+ predictions = predictions[inds]
+
+ ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
+ anno = coco_api.loadAnns(ann_ids)
+ gt_boxes = [
+ BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
+ for obj in anno
+ if obj["iscrowd"] == 0
+ ]
+ gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes
+ gt_boxes = Boxes(gt_boxes)
+ gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
+
+ if len(gt_boxes) == 0 or len(predictions) == 0:
+ continue
+
+ valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
+ gt_boxes = gt_boxes[valid_gt_inds]
+
+ num_pos += len(gt_boxes)
+
+ if len(gt_boxes) == 0:
+ continue
+
+ if limit is not None and len(predictions) > limit:
+ predictions = predictions[:limit]
+
+ overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
+
+ _gt_overlaps = torch.zeros(len(gt_boxes))
+ for j in range(min(len(predictions), len(gt_boxes))):
+ # find which proposal box maximally covers each gt box
+ # and get the iou amount of coverage for each gt box
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+ # find which gt box is 'best' covered (i.e. 'best' = most iou)
+ gt_ovr, gt_ind = max_overlaps.max(dim=0)
+ assert gt_ovr >= 0
+ # find the proposal box that covers the best covered gt box
+ box_ind = argmax_overlaps[gt_ind]
+ # record the iou coverage of this gt box
+ _gt_overlaps[j] = overlaps[box_ind, gt_ind]
+ assert _gt_overlaps[j] == gt_ovr
+ # mark the proposal box and the gt box as used
+ overlaps[box_ind, :] = -1
+ overlaps[:, gt_ind] = -1
+
+ # append recorded iou coverage level
+ gt_overlaps.append(_gt_overlaps)
+ gt_overlaps = (
+ torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
+ )
+ gt_overlaps, _ = torch.sort(gt_overlaps)
+
+ if thresholds is None:
+ step = 0.05
+ thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+ recalls = torch.zeros_like(thresholds)
+ # compute recall for each iou threshold
+ for i, t in enumerate(thresholds):
+ recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
+ # ar = 2 * np.trapz(recalls, thresholds)
+ ar = recalls.mean()
+ return {
+ "ar": ar,
+ "recalls": recalls,
+ "thresholds": thresholds,
+ "gt_overlaps": gt_overlaps,
+ "num_pos": num_pos,
+ }
+
+
def _evaluate_predictions_on_coco(
    coco_gt,
    coco_results,
    iou_type,
    kpt_oks_sigmas=None,
    cocoeval_fn=COCOeval_opt,
    img_ids=None,
    max_dets_per_image=None,
):
    """
    Run COCO evaluation (evaluate, accumulate, summarize) for one iou_type.

    Args:
        coco_gt: COCO API object holding the ground truth.
        coco_results (list[dict]): non-empty list of COCO-format results.
        iou_type (str): "bbox", "segm" or "keypoints".
        kpt_oks_sigmas: optional override of the keypoint OKS sigmas.
        cocoeval_fn: evaluator class/factory called as
            `cocoeval_fn(coco_gt, coco_dt, iou_type)`.
        img_ids: optional list of image ids to restrict evaluation to.
        max_dets_per_image: optional maxDets list with at least 3 entries.

    Returns:
        the evaluator object after `summarize()` has been called.
    """
    assert len(coco_results) > 0

    if iou_type == "segm":
        # When evaluating mask AP, if the results contain bbox, cocoapi will
        # use the box area as the area of the instance, instead of the mask area.
        # This leads to a different definition of small/medium/large.
        # We remove the bbox field (on a copy) to let mask AP use mask area.
        coco_results = copy.deepcopy(coco_results)
        for det in coco_results:
            det.pop("bbox", None)

    coco_dt = coco_gt.loadRes(coco_results)
    coco_eval = cocoeval_fn(coco_gt, coco_dt, iou_type)

    if max_dets_per_image is not None:
        assert (
            len(max_dets_per_image) >= 3
        ), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3"
        # A custom detection limit needs COCOevalMaxDets so that AP is
        # summarized at that limit.
        if max_dets_per_image[2] != 100:
            coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type)
    else:
        # Default from COCOEval
        max_dets_per_image = [1, 10, 100]

    is_keypoints = iou_type == "keypoints"
    if not is_keypoints:
        coco_eval.params.maxDets = max_dets_per_image

    if img_ids is not None:
        coco_eval.params.imgIds = img_ids

    if is_keypoints:
        # Use the COCO default keypoint OKS sigmas unless overrides are specified
        if kpt_oks_sigmas:
            assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "annotator.oneformer.pycocotools is too old!"
            coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
        # COCOAPI requires every detection and every gt to have keypoints, so
        # we just take the first entry from both
        first_dt = coco_results[0]
        first_gt = next(iter(coco_gt.anns.values()))
        num_keypoints_dt = len(first_dt["keypoints"]) // 3
        num_keypoints_gt = len(first_gt["keypoints"]) // 3
        num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas)
        assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, (
            f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. "
            f"Ground truth contains {num_keypoints_gt} keypoints. "
            f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. "
            "They have to agree with each other. For meaning of OKS, please refer to "
            "http://cocodataset.org/#keypoints-eval."
        )

    for stage in (coco_eval.evaluate, coco_eval.accumulate, coco_eval.summarize):
        stage()

    return coco_eval
+
+
class COCOevalMaxDets(COCOeval):
    """
    Modified version of COCOeval for evaluating AP with a custom
    maxDets (by default for COCO, maxDets is 100)
    """

    def summarize(self):
        """
        Compute and display summary metrics for evaluation results given
        a custom value for max_dets_per_image.

        Prints one line per metric and stores the values in `self.stats`.

        Raises:
            Exception: if `accumulate()` has not populated `self.eval` yet.
            ValueError: if `self.params.iouType` is not "segm", "bbox" or "keypoints".
        """

        def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100):
            # ap == 1 selects precision (AP), anything else selects recall (AR).
            p = self.params
            iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}"
            titleStr = "Average Precision" if ap == 1 else "Average Recall"
            typeStr = "(AP)" if ap == 1 else "(AR)"
            iouStr = (
                "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1])
                if iouThr is None
                else "{:0.2f}".format(iouThr)
            )

            # indices of the requested area range / detection limit
            aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
            mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
            if ap == 1:
                # dimension of precision: [TxRxKxAxM]
                s = self.eval["precision"]
                # IoU
                if iouThr is not None:
                    t = np.where(iouThr == p.iouThrs)[0]
                    s = s[t]
                s = s[:, :, :, aind, mind]
            else:
                # dimension of recall: [TxKxAxM]
                s = self.eval["recall"]
                if iouThr is not None:
                    t = np.where(iouThr == p.iouThrs)[0]
                    s = s[t]
                s = s[:, :, aind, mind]
            # entries equal to -1 are excluded; report -1 when nothing is left
            if len(s[s > -1]) == 0:
                mean_s = -1
            else:
                mean_s = np.mean(s[s > -1])
            print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
            return mean_s

        def _summarizeDets():
            stats = np.zeros((12,))
            # Evaluate AP using the custom limit on maximum detections per image
            stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
            stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2])
            stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2])
            stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2])
            stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2])
            stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2])
            stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
            stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
            stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
            stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2])
            stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2])
            stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2])
            return stats

        def _summarizeKps():
            stats = np.zeros((10,))
            stats[0] = _summarize(1, maxDets=20)
            stats[1] = _summarize(1, maxDets=20, iouThr=0.5)
            stats[2] = _summarize(1, maxDets=20, iouThr=0.75)
            stats[3] = _summarize(1, maxDets=20, areaRng="medium")
            stats[4] = _summarize(1, maxDets=20, areaRng="large")
            stats[5] = _summarize(0, maxDets=20)
            stats[6] = _summarize(0, maxDets=20, iouThr=0.5)
            stats[7] = _summarize(0, maxDets=20, iouThr=0.75)
            stats[8] = _summarize(0, maxDets=20, areaRng="medium")
            stats[9] = _summarize(0, maxDets=20, areaRng="large")
            return stats

        if not self.eval:
            raise Exception("Please run accumulate() first")
        iouType = self.params.iouType
        if iouType == "segm" or iouType == "bbox":
            summarize = _summarizeDets
        elif iouType == "keypoints":
            summarize = _summarizeKps
        else:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError("Unsupported iouType: {}".format(iouType))
        self.stats = summarize()

    def __str__(self):
        """
        Return the summary table as text.

        `__str__` must return a string; the summary that `summarize()` prints
        is captured and returned instead of returning None (which made
        `str(obj)` raise TypeError).
        """
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            self.summarize()
        return buffer.getvalue()
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/evaluator.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cddc296432cbb6f11caf3c3be98833a50778ffb
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/evaluator.py
@@ -0,0 +1,224 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import datetime
+import logging
+import time
+from collections import OrderedDict, abc
+from contextlib import ExitStack, contextmanager
+from typing import List, Union
+import torch
+from torch import nn
+
+from annotator.oneformer.detectron2.utils.comm import get_world_size, is_main_process
+from annotator.oneformer.detectron2.utils.logger import log_every_n_seconds
+
+
class DatasetEvaluator:
    """
    Interface shared by all dataset evaluators.

    :func:`inference_on_dataset` feeds every batch of model inputs and outputs
    to a DatasetEvaluator. Implementations gather whatever they need in
    :meth:`process` and turn it into final metrics in :meth:`evaluate`.

    All methods of this base class are no-ops that return None.
    """

    def reset(self):
        """
        Get ready for a fresh round of evaluation.
        Callers should invoke this before they start feeding data.
        """
        return None

    def process(self, inputs, outputs):
        """
        Consume one pair of model inputs and outputs.
        When both are batches, they can be walked in lockstep with `zip`:

        .. code-block:: python

            for input_, output in zip(inputs, outputs):
                # do evaluation on single input/output pair
                ...

        Args:
            inputs (list): what the model was called with.
            outputs (list): what `model(inputs)` returned.
        """
        return None

    def evaluate(self):
        """
        Summarize the performance once all input/output pairs were processed.

        Returns:
            dict:
                Evaluators are free to return a dict of any layout the user
                knows how to handle. train_net.py expects this layout:

                * key: the name of the task (e.g., bbox)
                * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
        """
        return None
+
+
class DatasetEvaluators(DatasetEvaluator):
    """
    Composite evaluator that fans out to several :class:`DatasetEvaluator` objects.

    Each call made on this wrapper is forwarded to every wrapped evaluator
    in the order they were given.
    """

    def __init__(self, evaluators):
        """
        Args:
            evaluators (list): the evaluators to combine.
        """
        super().__init__()
        self._evaluators = evaluators

    def reset(self):
        """Reset every wrapped evaluator."""
        for child in self._evaluators:
            child.reset()

    def process(self, inputs, outputs):
        """Hand the same inputs/outputs pair to every wrapped evaluator."""
        for child in self._evaluators:
            child.process(inputs, outputs)

    def evaluate(self):
        """
        Run every wrapped evaluator and merge their result dicts.

        Results are merged only on the main process and only for evaluators
        that returned something; duplicate keys trigger an assertion.

        Returns:
            OrderedDict: the merged results (empty on non-main processes).
        """
        merged = OrderedDict()
        for child in self._evaluators:
            partial = child.evaluate()
            if not is_main_process() or partial is None:
                continue
            for key, value in partial.items():
                assert (
                    key not in merged
                ), "Different evaluators produce results with the same key {}".format(key)
                merged[key] = value
        return merged
+
+
def inference_on_dataset(
    model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None]
):
    """
    Feed every batch of `data_loader` through `model`, pass the results to
    `evaluator`, and log timing statistics of `model.__call__` along the way.
    The model is run in eval mode.

    Args:
        model (callable): a callable which takes an object from
            `data_loader` and returns some outputs.

            If it's an nn.Module, it will be temporarily set to `eval` mode.
            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
            but don't want to do any evaluation.

    Returns:
        The return value of `evaluator.evaluate()`, or an empty dict when that is None.
    """
    num_devices = get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} batches".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length
    if evaluator is None:
        # benchmarking only: use an evaluator that does nothing
        evaluator = DatasetEvaluators([])
    if isinstance(evaluator, abc.MutableSequence):
        evaluator = DatasetEvaluators(evaluator)
    evaluator.reset()

    # The first few iterations are treated as warm-up and excluded from timing.
    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_data_time = 0
    total_compute_time = 0
    total_eval_time = 0
    with ExitStack() as stack:
        if isinstance(model, nn.Module):
            stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())

        start_data_time = time.perf_counter()
        for idx, inputs in enumerate(data_loader):
            total_data_time += time.perf_counter() - start_data_time
            if idx == num_warmup:
                # warm-up is over: restart all timers
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0

            start_compute_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time

            start_eval_time = time.perf_counter()
            evaluator.process(inputs, outputs)
            total_eval_time += time.perf_counter() - start_eval_time

            # number of iterations covered by the current timers
            if idx >= num_warmup:
                iters_after_start = idx + 1 - num_warmup
            else:
                iters_after_start = idx + 1
            data_seconds_per_iter = total_data_time / iters_after_start
            compute_seconds_per_iter = total_compute_time / iters_after_start
            eval_seconds_per_iter = total_eval_time / iters_after_start
            total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
            if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
                remaining_iters = total - idx - 1
                eta = datetime.timedelta(seconds=int(total_seconds_per_iter * remaining_iters))
                progress_message = (
                    f"Inference done {idx + 1}/{total}. "
                    f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                    f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                    f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                    f"Total: {total_seconds_per_iter:.4f} s/iter. "
                    f"ETA={eta}"
                )
                log_every_n_seconds(logging.INFO, progress_message, n=5)
            start_data_time = time.perf_counter()

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    timed_iters = total - num_warmup
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_time_str, total_time / timed_iters, num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / timed_iters, num_devices
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in main process.
    # Hand back an empty dict instead so downstream code has less to special-case.
    return {} if results is None else results
+
+
@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    The previous mode is restored even when the body of the `with` block
    raises, so a failed inference run does not leave the model in eval mode.

    Args:
        model: a torch Module (anything exposing `training`, `eval()` and
            `train(mode)`)
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions alike.
        model.train(training_mode)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/fast_eval_api.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/fast_eval_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad1a8f82350098bafe56f6d9481626e812717052
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/fast_eval_api.py
@@ -0,0 +1,121 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import numpy as np
+import time
+from annotator.oneformer.pycocotools.cocoeval import COCOeval
+
+from annotator.oneformer.detectron2 import _C
+
+logger = logging.getLogger(__name__)
+
+
+class COCOeval_opt(COCOeval):
+    """
+    This is a slightly modified version of the original COCO API, where the functions evaluateImg()
+    and accumulate() are implemented in C++ to speedup evaluation
+    """
+
+    def evaluate(self):
+        """
+        Run per image evaluation on given images and store results in self._evalImgs_cpp, a
+        datastructure that isn't readable from Python but is used by a c++ implementation of
+        accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure
+        self.evalImgs because this datastructure is a computational bottleneck.
+        :return: None
+        """
+        tic = time.time()
+
+        p = self.params
+        # add backward compatibility if useSegm is specified in params
+        if p.useSegm is not None:
+            p.iouType = "segm" if p.useSegm == 1 else "bbox"
+        logger.info("Evaluate annotation type *{}*".format(p.iouType))
+        p.imgIds = list(np.unique(p.imgIds))
+        if p.useCats:
+            p.catIds = list(np.unique(p.catIds))
+        p.maxDets = sorted(p.maxDets)
+        self.params = p
+
+        self._prepare()  # bottleneck
+
+        # loop through images, area range, max detection number
+        catIds = p.catIds if p.useCats else [-1]
+
+        if p.iouType == "segm" or p.iouType == "bbox":
+            computeIoU = self.computeIoU
+        elif p.iouType == "keypoints":
+            computeIoU = self.computeOks
+        self.ious = {
+            (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds
+        }  # bottleneck
+
+        maxDet = p.maxDets[-1]
+
+        # <<<< Beginning of code differences with original COCO API
+        def convert_instances_to_cpp(instances, is_det=False):
+            # Convert annotations for a list of instances in an image to a format that's fast
+            # to access in C++
+            instances_cpp = []
+            for instance in instances:
+                instance_cpp = _C.InstanceAnnotation(
+                    int(instance["id"]),
+                    instance["score"] if is_det else instance.get("score", 0.0),
+                    instance["area"],
+                    bool(instance.get("iscrowd", 0)),
+                    bool(instance.get("ignore", 0)),
+                )
+                instances_cpp.append(instance_cpp)
+            return instances_cpp
+
+        # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
+        ground_truth_instances = [
+            [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
+            for imgId in p.imgIds
+        ]
+        detected_instances = [
+            [convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) for catId in p.catIds]
+            for imgId in p.imgIds
+        ]
+        ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]
+
+        if not p.useCats:
+            # For each image, flatten per-category lists into a single list
+            ground_truth_instances = [[[o for c in i for o in c]] for i in ground_truth_instances]
+            detected_instances = [[[o for c in i for o in c]] for i in detected_instances]
+
+        # Call C++ implementation of self.evaluateImgs()
+        self._evalImgs_cpp = _C.COCOevalEvaluateImages(
+            p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances
+        )
+        self._evalImgs = None
+
+        self._paramsEval = copy.deepcopy(self.params)
+        toc = time.time()
+        logger.info("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic))
+        # >>>> End of code differences with original COCO API
+
+    def accumulate(self):
+        """
+        Accumulate per image evaluation results and store the result in self.eval. Does not
+        support changing parameter settings from those used by self.evaluate()
+        """
+        logger.info("Accumulating evaluation results...")
+        tic = time.time()
+        assert hasattr(
+            self, "_evalImgs_cpp"
+        ), "evaluate() must be called before accumulate() is called."
+
+        self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)
+
+        # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
+        self.eval["recall"] = np.array(self.eval["recall"]).reshape(
+            self.eval["counts"][:1] + self.eval["counts"][2:]
+        )
+
+        # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
+        # num_area_ranges X num_max_detections
+        self.eval["precision"] = np.array(self.eval["precision"]).reshape(self.eval["counts"])
+        self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
+        toc = time.time()
+        logger.info("COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic))
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/lvis_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/lvis_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d712ef262789edb85392cb54577c3a6b15e223e
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/lvis_evaluation.py
@@ -0,0 +1,380 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import itertools
+import json
+import logging
+import os
+import pickle
+from collections import OrderedDict
+import torch
+
+import annotator.oneformer.detectron2.utils.comm as comm
+from annotator.oneformer.detectron2.config import CfgNode
+from annotator.oneformer.detectron2.data import MetadataCatalog
+from annotator.oneformer.detectron2.structures import Boxes, BoxMode, pairwise_iou
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+from annotator.oneformer.detectron2.utils.logger import create_small_table
+
+from .coco_evaluation import instances_to_coco_json
+from .evaluator import DatasetEvaluator
+
+
+class LVISEvaluator(DatasetEvaluator):
+    """
+    Evaluate object proposal and instance detection/segmentation outputs using
+    LVIS's metrics and evaluation API.
+    """
+
+    def __init__(
+        self,
+        dataset_name,
+        tasks=None,
+        distributed=True,
+        output_dir=None,
+        *,
+        max_dets_per_image=None,
+    ):
+        """
+        Args:
+            dataset_name (str): name of the dataset to be evaluated.
+                It must have the following corresponding metadata:
+                "json_file": the path to the LVIS format annotation
+            tasks (tuple[str]): tasks that can be evaluated under the given
+                configuration. A task is one of "bbox", "segm".
+                By default, will infer this automatically from predictions.
+            distributed (True): if True, will collect results from all ranks for evaluation.
+                Otherwise, will evaluate the results in the current process.
+            output_dir (str): optional, an output directory to dump results.
+            max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP
+                This limit, by default of the LVIS dataset, is 300.
+        """
+        from lvis import LVIS
+
+        self._logger = logging.getLogger(__name__)
+
+        if tasks is not None and isinstance(tasks, CfgNode):
+            self._logger.warning(
+                "LVIS Evaluator instantiated using config, this is deprecated behavior."
+                " Please pass in explicit arguments instead."
+            )
+            self._tasks = None  # Infering it from predictions should be better
+        else:
+            self._tasks = tasks
+
+        self._distributed = distributed
+        self._output_dir = output_dir
+        self._max_dets_per_image = max_dets_per_image
+
+        self._cpu_device = torch.device("cpu")
+
+        self._metadata = MetadataCatalog.get(dataset_name)
+        json_file = PathManager.get_local_path(self._metadata.json_file)
+        self._lvis_api = LVIS(json_file)
+        # Test set json files do not contain annotations (evaluation must be
+        # performed using the LVIS evaluation server).
+        self._do_evaluation = len(self._lvis_api.get_ann_ids()) > 0
+
+    def reset(self):
+        self._predictions = []
+
+    def process(self, inputs, outputs):
+        """
+        Args:
+            inputs: the inputs to a LVIS model (e.g., GeneralizedRCNN).
+                It is a list of dict. Each dict corresponds to an image and
+                contains keys like "height", "width", "file_name", "image_id".
+            outputs: the outputs of a LVIS model. It is a list of dicts with key
+                "instances" that contains :class:`Instances`.
+        """
+        for input, output in zip(inputs, outputs):
+            prediction = {"image_id": input["image_id"]}
+
+            if "instances" in output:
+                instances = output["instances"].to(self._cpu_device)
+                prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
+            if "proposals" in output:
+                prediction["proposals"] = output["proposals"].to(self._cpu_device)
+            self._predictions.append(prediction)
+
+    def evaluate(self):
+        if self._distributed:
+            comm.synchronize()
+            predictions = comm.gather(self._predictions, dst=0)
+            predictions = list(itertools.chain(*predictions))
+
+            if not comm.is_main_process():
+                return
+        else:
+            predictions = self._predictions
+
+        if len(predictions) == 0:
+            self._logger.warning("[LVISEvaluator] Did not receive valid predictions.")
+            return {}
+
+        if self._output_dir:
+            PathManager.mkdirs(self._output_dir)
+            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
+            with PathManager.open(file_path, "wb") as f:
+                torch.save(predictions, f)
+
+        self._results = OrderedDict()
+        if "proposals" in predictions[0]:
+            self._eval_box_proposals(predictions)
+        if "instances" in predictions[0]:
+            self._eval_predictions(predictions)
+        # Copy so the caller can do whatever with results
+        return copy.deepcopy(self._results)
+
+    def _tasks_from_predictions(self, predictions):
+        for pred in predictions:
+            if "segmentation" in pred:
+                return ("bbox", "segm")
+        return ("bbox",)
+
+    def _eval_predictions(self, predictions):
+        """
+        Evaluate predictions. Fill self._results with the metrics of the tasks.
+        Args:
+            predictions (list[dict]): list of outputs from the model
+        """
+        self._logger.info("Preparing results in the LVIS format ...")
+        # Copy each result dict: the category id remapping below must not alter self._predictions.
+        lvis_results = [dict(r) for x in predictions for r in x["instances"]]
+        tasks = self._tasks or self._tasks_from_predictions(lvis_results)
+
+        # LVIS evaluator can be used to evaluate results for COCO dataset categories.
+        # In this case `_metadata` variable will have a field with COCO-specific category mapping.
+        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+            reverse_id_mapping = {
+                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
+            }
+            for result in lvis_results:
+                result["category_id"] = reverse_id_mapping[result["category_id"]]
+        else:
+            # unmap the category ids for LVIS (from 0-indexed to 1-indexed)
+            for result in lvis_results:
+                result["category_id"] += 1
+
+        if self._output_dir:
+            file_path = os.path.join(self._output_dir, "lvis_instances_results.json")
+            self._logger.info("Saving results to {}".format(file_path))
+            with PathManager.open(file_path, "w") as f:
+                f.write(json.dumps(lvis_results))
+                f.flush()
+
+        if not self._do_evaluation:
+            self._logger.info("Annotations are not available for evaluation.")
+            return
+
+        self._logger.info("Evaluating predictions ...")
+        for task in sorted(tasks):
+            res = _evaluate_predictions_on_lvis(
+                self._lvis_api,
+                lvis_results,
+                task,
+                max_dets_per_image=self._max_dets_per_image,
+                class_names=self._metadata.get("thing_classes"),
+            )
+            self._results[task] = res
+
+    def _eval_box_proposals(self, predictions):
+        """
+        Evaluate the box proposals in predictions.
+        Fill self._results with the metrics for "box_proposals" task.
+        """
+        if self._output_dir:
+            # Saving generated box proposals to file.
+            # Predicted box_proposals are in XYXY_ABS mode.
+            bbox_mode = BoxMode.XYXY_ABS.value
+            ids, boxes, objectness_logits = [], [], []
+            for prediction in predictions:
+                ids.append(prediction["image_id"])
+                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
+                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())
+
+            proposal_data = {
+                "boxes": boxes,
+                "objectness_logits": objectness_logits,
+                "ids": ids,
+                "bbox_mode": bbox_mode,
+            }
+            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
+                pickle.dump(proposal_data, f)
+
+        if not self._do_evaluation:
+            self._logger.info("Annotations are not available for evaluation.")
+            return
+
+        self._logger.info("Evaluating bbox proposals ...")
+        res = {}
+        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
+        for limit in [100, 1000]:
+            for area, suffix in areas.items():
+                stats = _evaluate_box_proposals(predictions, self._lvis_api, area=area, limit=limit)
+                key = "AR{}@{:d}".format(suffix, limit)
+                res[key] = float(stats["ar"].item() * 100)
+        self._logger.info("Proposal metrics: \n" + create_small_table(res))
+        self._results["box_proposals"] = res
+
+
+# inspired from Detectron:
+# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
+def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area="all", limit=None):
+    """
+    Evaluate detection proposal recall metrics. This function is a much faster alternative
+    to the official LVIS API recall evaluation code, but produces slightly different results.
+    `area` selects the GT area range; `limit` caps the proposals used per image (highest
+    objectness first); `thresholds` are the IoU thresholds (default: 0.5 to 0.95, step 0.05).
+    Returns a dict with keys "ar", "recalls", "thresholds", "gt_overlaps" and "num_pos".
+    """
+    areas = {
+        "all": 0,
+        "small": 1,
+        "medium": 2,
+        "large": 3,
+        "96-128": 4,
+        "128-256": 5,
+        "256-512": 6,
+        "512-inf": 7,
+    }
+    area_ranges = [
+        [0**2, 1e5**2],  # all
+        [0**2, 32**2],  # small
+        [32**2, 96**2],  # medium
+        [96**2, 1e5**2],  # large
+        [96**2, 128**2],  # 96-128
+        [128**2, 256**2],  # 128-256
+        [256**2, 512**2],  # 256-512
+        [512**2, 1e5**2],
+    ]  # 512-inf
+    assert area in areas, "Unknown area range: {}".format(area)
+    area_range = area_ranges[areas[area]]
+    gt_overlaps = []  # one tensor of matched IoUs per evaluated image
+    num_pos = 0  # number of GT boxes inside `area_range` over the evaluated images
+
+    for prediction_dict in dataset_predictions:
+        predictions = prediction_dict["proposals"]
+
+        # sort predictions in descending order
+        # TODO maybe remove this and make it explicit in the documentation
+        inds = predictions.objectness_logits.sort(descending=True)[1]
+        predictions = predictions[inds]
+
+        ann_ids = lvis_api.get_ann_ids(img_ids=[prediction_dict["image_id"]])
+        anno = lvis_api.load_anns(ann_ids)
+        gt_boxes = [
+            BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno
+        ]
+        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
+        gt_boxes = Boxes(gt_boxes)
+        gt_areas = torch.as_tensor([obj["area"] for obj in anno])
+
+        if len(gt_boxes) == 0 or len(predictions) == 0:
+            continue  # such images are skipped before their GT boxes are counted in num_pos
+
+        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
+        gt_boxes = gt_boxes[valid_gt_inds]
+
+        num_pos += len(gt_boxes)
+
+        if len(gt_boxes) == 0:
+            continue
+
+        if limit is not None and len(predictions) > limit:
+            predictions = predictions[:limit]
+
+        overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
+
+        _gt_overlaps = torch.zeros(len(gt_boxes))  # unmatched GT boxes keep an IoU of 0
+        for j in range(min(len(predictions), len(gt_boxes))):
+            # find which proposal box maximally covers each gt box
+            # and get the iou amount of coverage for each gt box
+            max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+            # find which gt box is 'best' covered (i.e. 'best' = most iou)
+            gt_ovr, gt_ind = max_overlaps.max(dim=0)
+            assert gt_ovr >= 0
+            # find the proposal box that covers the best covered gt box
+            box_ind = argmax_overlaps[gt_ind]
+            # record the iou coverage of this gt box
+            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
+            assert _gt_overlaps[j] == gt_ovr
+            # mark the proposal box and the gt box as used
+            overlaps[box_ind, :] = -1
+            overlaps[:, gt_ind] = -1
+
+        # append recorded iou coverage level
+        gt_overlaps.append(_gt_overlaps)
+    gt_overlaps = (
+        torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
+    )
+    gt_overlaps, _ = torch.sort(gt_overlaps)
+
+    if thresholds is None:
+        step = 0.05
+        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+    recalls = torch.zeros_like(thresholds)
+    # compute recall for each iou threshold
+    for i, t in enumerate(thresholds):
+        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)  # NOTE(review): num_pos may be 0
+    # ar = 2 * np.trapz(recalls, thresholds)
+    ar = recalls.mean()
+    return {
+        "ar": ar,
+        "recalls": recalls,
+        "thresholds": thresholds,
+        "gt_overlaps": gt_overlaps,
+        "num_pos": num_pos,
+    }
+
+
+def _evaluate_predictions_on_lvis(
+    lvis_gt, lvis_results, iou_type, max_dets_per_image=None, class_names=None
+):
+    """
+    Args:
+        lvis_gt: ground truth object handed to LVISResults and LVISEval
+        lvis_results (list[dict]): predictions; an empty list yields NaN for every metric
+        iou_type (str): "bbox" or "segm"
+        max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP
+            This limit, by default of the LVIS dataset, is 300.
+        class_names (None or list[str]): accepted for API compatibility; not used here.
+
+    Returns:
+        a dict of {metric name: score}
+    """
+    metrics = {
+        "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
+        "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"],
+    }[iou_type]
+
+    logger = logging.getLogger(__name__)
+
+    if len(lvis_results) == 0:  # TODO: check if needed
+        logger.warning("No predictions from the model!")
+        return {metric: float("nan") for metric in metrics}
+
+    if iou_type == "segm":
+        lvis_results = copy.deepcopy(lvis_results)
+        # When evaluating mask AP, if the results contain bbox, LVIS API will
+        # use the box area as the area of the instance, instead of the mask area.
+        # This leads to a different definition of small/medium/large.
+        # We remove the bbox field to let mask AP use mask area.
+        for c in lvis_results:
+            c.pop("bbox", None)
+
+    if max_dets_per_image is None:
+        max_dets_per_image = 300  # Default for LVIS dataset
+
+    from lvis import LVISEval, LVISResults
+
+    logger.info(f"Evaluating with max detections per image = {max_dets_per_image}")
+    lvis_results = LVISResults(lvis_gt, lvis_results, max_dets=max_dets_per_image)
+    lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type)
+    lvis_eval.run()
+    lvis_eval.print_results()
+
+    # Pull the standard metrics from the LVIS results
+    results = lvis_eval.get_results()
+    results = {metric: float(results[metric] * 100) for metric in metrics}
+    logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results))
+    return results
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/panoptic_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/panoptic_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf77fe061291f44381f8417e82e8b2bc7c5a60c6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/panoptic_evaluation.py
@@ -0,0 +1,199 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import io
+import itertools
+import json
+import logging
+import numpy as np
+import os
+import tempfile
+from collections import OrderedDict
+from typing import Optional
+from PIL import Image
+from tabulate import tabulate
+
+from annotator.oneformer.detectron2.data import MetadataCatalog
+from annotator.oneformer.detectron2.utils import comm
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .evaluator import DatasetEvaluator
+
+logger = logging.getLogger(__name__)
+
+
+class COCOPanopticEvaluator(DatasetEvaluator):
+    """
+    Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
+    It saves panoptic segmentation prediction in `output_dir`
+
+    It contains a synchronize call and has to be called from all workers.
+    """
+
+    def __init__(self, dataset_name: str, output_dir: Optional[str] = None):
+        """
+        Args:
+            dataset_name: name of the dataset
+            output_dir: directory (created if needed) for predictions.json; a temp dir if None.
+        """
+        self._metadata = MetadataCatalog.get(dataset_name)
+        self._thing_contiguous_id_to_dataset_id = {
+            v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
+        }
+        self._stuff_contiguous_id_to_dataset_id = {
+            v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items()
+        }
+
+        self._output_dir = output_dir
+        if self._output_dir is not None:
+            PathManager.mkdirs(self._output_dir)
+
+    def reset(self):
+        self._predictions = []
+
+    def _convert_category_id(self, segment_info):
+        isthing = segment_info.pop("isthing", None)  # the key is removed from the dict in place
+        if isthing is None:
+            # the model produces panoptic category id directly. No more conversion needed
+            return segment_info
+        if isthing is True:
+            segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
+                segment_info["category_id"]
+            ]
+        else:
+            segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
+                segment_info["category_id"]
+            ]
+        return segment_info
+
+    def process(self, inputs, outputs):
+        from panopticapi.utils import id2rgb
+
+        for input, output in zip(inputs, outputs):
+            panoptic_img, segments_info = output["panoptic_seg"]
+            panoptic_img = panoptic_img.cpu().numpy()
+            if segments_info is None:
+                # If "segments_info" is None, we assume "panoptic_img" is a
+                # H*W int32 image storing the panoptic_id in the format of
+                # category_id * label_divisor + instance_id. We reserve -1 for
+                # VOID label, and add 1 to panoptic_img since the official
+                # evaluation script uses 0 for VOID label.
+                label_divisor = self._metadata.label_divisor
+                segments_info = []
+                for panoptic_label in np.unique(panoptic_img):
+                    if panoptic_label == -1:
+                        # VOID region.
+                        continue
+                    pred_class = panoptic_label // label_divisor
+                    isthing = (
+                        pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values()
+                    )
+                    segments_info.append(
+                        {
+                            "id": int(panoptic_label) + 1,
+                            "category_id": int(pred_class),
+                            "isthing": bool(isthing),
+                        }
+                    )
+                # Official evaluation script uses 0 for VOID label.
+                panoptic_img += 1
+
+            file_name = os.path.basename(input["file_name"])
+            file_name_png = os.path.splitext(file_name)[0] + ".png"
+            with io.BytesIO() as out:  # the PNG is encoded in memory, not written to disk here
+                Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
+                segments_info = [self._convert_category_id(x) for x in segments_info]
+                self._predictions.append(
+                    {
+                        "image_id": input["image_id"],
+                        "file_name": file_name_png,
+                        "png_string": out.getvalue(),
+                        "segments_info": segments_info,
+                    }
+                )
+
+    def evaluate(self):
+        comm.synchronize()
+
+        self._predictions = comm.gather(self._predictions)
+        self._predictions = list(itertools.chain(*self._predictions))
+        if not comm.is_main_process():
+            return  # only the main process computes metrics; other ranks return None
+
+        # PanopticApi requires local files
+        gt_json = PathManager.get_local_path(self._metadata.panoptic_json)
+        gt_folder = PathManager.get_local_path(self._metadata.panoptic_root)
+
+        with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
+            logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
+            for p in self._predictions:
+                with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
+                    f.write(p.pop("png_string"))  # pop: the bytes must not reach json.dumps below
+
+            with open(gt_json, "r") as f:
+                json_data = json.load(f)
+            json_data["annotations"] = self._predictions  # GT json with annotations replaced
+
+            output_dir = self._output_dir or pred_dir
+            predictions_json = os.path.join(output_dir, "predictions.json")
+            with PathManager.open(predictions_json, "w") as f:
+                f.write(json.dumps(json_data))
+
+            from panopticapi.evaluation import pq_compute
+
+            with contextlib.redirect_stdout(io.StringIO()):  # discard pq_compute's stdout output
+                pq_res = pq_compute(
+                    gt_json,
+                    PathManager.get_local_path(predictions_json),
+                    gt_folder=gt_folder,
+                    pred_folder=pred_dir,
+                )
+
+        res = {}  # every metric below is scaled by 100
+        res["PQ"] = 100 * pq_res["All"]["pq"]
+        res["SQ"] = 100 * pq_res["All"]["sq"]
+        res["RQ"] = 100 * pq_res["All"]["rq"]
+        res["PQ_th"] = 100 * pq_res["Things"]["pq"]
+        res["SQ_th"] = 100 * pq_res["Things"]["sq"]
+        res["RQ_th"] = 100 * pq_res["Things"]["rq"]
+        res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
+        res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
+        res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
+
+        results = OrderedDict({"panoptic_seg": res})
+        _print_panoptic_results(pq_res)
+
+        return results
+
+
+def _print_panoptic_results(pq_res):
+    """Log PQ/SQ/RQ (scaled by 100) and the category count for All/Things/Stuff as a table."""
+    rows = [
+        [group, *(pq_res[group][metric] * 100 for metric in ("pq", "sq", "rq")), pq_res[group]["n"]]
+        for group in ("All", "Things", "Stuff")
+    ]
+    column_names = ["", "PQ", "SQ", "RQ", "#categories"]
+    layout = dict(tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center")
+    table = tabulate(rows, headers=column_names, **layout)
+    logger.info("Panoptic Evaluation Results:\n" + table)
+
+
+if __name__ == "__main__":
+    # Stand-alone entry point: score an existing prediction folder against ground truth.
+    from annotator.oneformer.detectron2.utils.logger import setup_logger
+
+    logger = setup_logger()
+    import argparse
+
+    cli = argparse.ArgumentParser()
+    for flag in ("--gt-json", "--gt-dir", "--pred-json", "--pred-dir"):
+        cli.add_argument(flag)
+    opts = cli.parse_args()
+
+    from panopticapi.evaluation import pq_compute
+
+    # Anything pq_compute writes to stdout is discarded; only the summary table is logged.
+    with contextlib.redirect_stdout(io.StringIO()):
+        pq_res = pq_compute(
+            opts.gt_json, opts.pred_json, gt_folder=opts.gt_dir, pred_folder=opts.pred_dir
+        )
+    _print_panoptic_results(pq_res)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/pascal_voc_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/pascal_voc_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2963e5dc5b6ed471f0c37056b35a350ea4cf020
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/pascal_voc_evaluation.py
@@ -0,0 +1,300 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import logging
+import numpy as np
+import os
+import tempfile
+import xml.etree.ElementTree as ET
+from collections import OrderedDict, defaultdict
+from functools import lru_cache
+import torch
+
+from annotator.oneformer.detectron2.data import MetadataCatalog
+from annotator.oneformer.detectron2.utils import comm
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .evaluator import DatasetEvaluator
+
+
+class PascalVOCDetectionEvaluator(DatasetEvaluator):
+    """
+    Evaluate Pascal VOC style AP for Pascal VOC dataset.
+    It contains a synchronization, therefore has to be called from all ranks.
+
+    Note that the concept of AP can be implemented in different ways and may not
+    produce identical results. This class mimics the implementation of the official
+    Pascal VOC Matlab API, and should produce similar but not identical results to the
+    official API.
+    """
+
+    def __init__(self, dataset_name):
+        """
+        Args:
+            dataset_name (str): name of the dataset, e.g., "voc_2007_test"
+        """
+        self._dataset_name = dataset_name
+        meta = MetadataCatalog.get(dataset_name)
+
+        # Too many tiny files, download all to local for speed.
+        annotation_dir_local = PathManager.get_local_path(
+            os.path.join(meta.dirname, "Annotations/")
+        )
+        self._anno_file_template = os.path.join(annotation_dir_local, "{}.xml")
+        self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
+        self._class_names = meta.thing_classes
+        assert meta.year in [2007, 2012], meta.year
+        self._is_2007 = meta.year == 2007  # selects the 11-point AP metric in evaluate()
+        self._cpu_device = torch.device("cpu")
+        self._logger = logging.getLogger(__name__)
+
+    def reset(self):
+        self._predictions = defaultdict(list)  # class name -> list of prediction strings
+
+    def process(self, inputs, outputs):
+        for input, output in zip(inputs, outputs):
+            image_id = input["image_id"]
+            instances = output["instances"].to(self._cpu_device)
+            boxes = instances.pred_boxes.tensor.numpy()
+            scores = instances.scores.tolist()
+            classes = instances.pred_classes.tolist()
+            for box, score, cls in zip(boxes, scores, classes):
+                xmin, ymin, xmax, ymax = box
+                # The inverse of data loading logic in `datasets/pascal_voc.py`
+                xmin += 1
+                ymin += 1
+                self._predictions[cls].append(
+                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
+                )
+
+    def evaluate(self):
+        """
+        Returns:
+            dict: has a key "bbox", whose value is a dict of "AP", "AP50", and "AP75".
+        """
+        all_predictions = comm.gather(self._predictions, dst=0)
+        if not comm.is_main_process():
+            return  # only the main process evaluates; other ranks return None
+        predictions = defaultdict(list)
+        for predictions_per_rank in all_predictions:
+            for clsid, lines in predictions_per_rank.items():
+                predictions[clsid].extend(lines)
+        del all_predictions
+
+        self._logger.info(
+            "Evaluating {} using {} metric. "
+            "Note that results do not use the official Matlab API.".format(
+                self._dataset_name, 2007 if self._is_2007 else 2012
+            )
+        )
+
+        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
+            res_file_template = os.path.join(dirname, "{}.txt")
+
+            aps = defaultdict(list)  # iou -> ap per class
+            for cls_id, cls_name in enumerate(self._class_names):
+                lines = predictions.get(cls_id, [""])  # no detections -> an empty results file
+
+                with open(res_file_template.format(cls_name), "w") as f:
+                    f.write("\n".join(lines))
+
+                for thresh in range(50, 100, 5):  # IoU thresholds 0.50, 0.55, ..., 0.95
+                    rec, prec, ap = voc_eval(
+                        res_file_template,
+                        self._anno_file_template,
+                        self._image_set_path,
+                        cls_name,
+                        ovthresh=thresh / 100.0,
+                        use_07_metric=self._is_2007,
+                    )
+                    aps[thresh].append(ap * 100)
+
+        ret = OrderedDict()
+        mAP = {iou: np.mean(x) for iou, x in aps.items()}  # mean over classes per threshold
+        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
+        return ret
+
+
+##############################################################################
+#
+# Below code is modified from
+# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
+# --------------------------------------------------------
+# Fast/er R-CNN
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Bharath Hariharan
+# --------------------------------------------------------
+
+"""Python implementation of the PASCAL VOC devkit's AP evaluation code."""
+
+
+@lru_cache(maxsize=None)
+def parse_rec(filename):
+    """Parse a PASCAL VOC xml file into a list of per-object dicts (memoized per filename)."""
+
+    def text_of(node, tag):
+        return node.find(tag).text
+
+    with PathManager.open(filename) as xml_file:
+        root = ET.parse(xml_file)
+
+    parsed = []
+    for node in root.findall("object"):
+        box = node.find("bndbox")
+        entry = {
+            "name": text_of(node, "name"),
+            "pose": text_of(node, "pose"),
+            "truncated": int(text_of(node, "truncated")),
+            "difficult": int(text_of(node, "difficult")),
+            "bbox": [int(text_of(box, corner)) for corner in ("xmin", "ymin", "xmax", "ymax")],
+        }
+        parsed.append(entry)
+    return parsed
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """
+    Compute VOC AP given precision and recall.
+
+    Args:
+        rec, prec: recall and precision arrays of equal length.
+        use_07_metric (bool): use the VOC 07 11-point method (default: False).
+
+    Returns:
+        the average precision as a scalar.
+    """
+    if not use_07_metric:
+        # Pad both curves with sentinel values so the area spans recall 0..1.
+        recall = np.concatenate(([0.0], rec, [1.0]))
+        precision = np.concatenate(([0.0], prec, [0.0]))
+
+        # Precision envelope: each entry becomes the max of itself and all later entries.
+        precision = np.maximum.accumulate(precision[::-1])[::-1]
+
+        # Only positions where recall changes add area: sum (\Delta recall) * prec.
+        steps = np.where(recall[1:] != recall[:-1])[0]
+        return np.sum((recall[steps + 1] - recall[steps]) * precision[steps + 1])
+
+    # 11 point metric: mean of the best precision at recall >= 0.0, 0.1, ..., 1.0
+    ap = 0.0
+    for level in np.arange(0.0, 1.1, 0.1):
+        reachable = rec >= level
+        best = np.max(prec[reachable]) if np.sum(reachable) != 0 else 0
+        ap = ap + best / 11.0
+    return ap
+
+
+def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
+    """rec, prec, ap = voc_eval(detpath, annopath, imagesetfile, classname,
+                                [ovthresh], [use_07_metric])
+
+    Top level function that does the PASCAL VOC evaluation.
+
+    detpath: Path to detections
+        detpath.format(classname) should produce the detection results file.
+    annopath: Path to annotations
+        annopath.format(imagename) should be the xml annotations file.
+    imagesetfile: Text file containing the list of images, one image per line.
+    classname: Category name
+    [ovthresh]: Overlap threshold (default = 0.5)
+    [use_07_metric]: Whether to use VOC07's 11 point AP computation
+        (default False)
+
+    Returns cumulative recall and precision arrays (one entry per detection, in order of
+    decreasing confidence) and the AP computed from them by voc_ap(). A detection only
+    counts as a true positive when its best IoU is strictly greater than ovthresh.
+    """
+    # assumes detections are in detpath.format(classname)
+    # assumes annotations are in annopath.format(imagename)
+    # assumes imagesetfile is a text file with each line an image name
+
+    # first load gt
+    # read list of images
+    with PathManager.open(imagesetfile, "r") as f:
+        lines = f.readlines()
+    imagenames = [x.strip() for x in lines]
+
+    # load annots
+    recs = {}
+    for imagename in imagenames:
+        recs[imagename] = parse_rec(annopath.format(imagename))
+
+    # extract gt objects for this class
+    class_recs = {}
+    npos = 0  # number of non-difficult ground-truth boxes of this class
+    for imagename in imagenames:
+        R = [obj for obj in recs[imagename] if obj["name"] == classname]
+        bbox = np.array([x["bbox"] for x in R])
+        difficult = np.array([x["difficult"] for x in R]).astype(bool)
+        # difficult = np.array([False for x in R]).astype(bool)  # treat all "difficult" as GT
+        det = [False] * len(R)  # per GT box: already matched by a detection?
+        npos = npos + sum(~difficult)
+        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
+
+    # read dets
+    detfile = detpath.format(classname)
+    with open(detfile, "r") as f:
+        lines = f.readlines()
+
+    splitlines = [x.strip().split(" ") for x in lines]  # fields: image_id score 4 box coords
+    image_ids = [x[0] for x in splitlines]
+    confidence = np.array([float(x[1]) for x in splitlines])
+    BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4)
+
+    # sort by confidence
+    sorted_ind = np.argsort(-confidence)
+    BB = BB[sorted_ind, :]
+    image_ids = [image_ids[x] for x in sorted_ind]
+
+    # go down dets and mark TPs and FPs
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+    for d in range(nd):
+        R = class_recs[image_ids[d]]  # NOTE(review): KeyError if the id is not in imagesetfile
+        bb = BB[d, :].astype(float)
+        ovmax = -np.inf
+        BBGT = R["bbox"].astype(float)
+
+        if BBGT.size > 0:
+            # compute overlaps
+            # intersection
+            ixmin = np.maximum(BBGT[:, 0], bb[0])
+            iymin = np.maximum(BBGT[:, 1], bb[1])
+            ixmax = np.minimum(BBGT[:, 2], bb[2])
+            iymax = np.minimum(BBGT[:, 3], bb[3])
+            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
+            ih = np.maximum(iymax - iymin + 1.0, 0.0)
+            inters = iw * ih
+
+            # union
+            uni = (
+                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
+                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
+                - inters
+            )
+
+            overlaps = inters / uni
+            ovmax = np.max(overlaps)
+            jmax = np.argmax(overlaps)
+
+        if ovmax > ovthresh:
+            if not R["difficult"][jmax]:  # matches to "difficult" boxes count as neither TP nor FP
+                if not R["det"][jmax]:
+                    tp[d] = 1.0
+                    R["det"][jmax] = 1
+                else:
+                    fp[d] = 1.0  # this ground-truth box was already matched: duplicate detection
+        else:
+            fp[d] = 1.0
+
+    # compute precision recall
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    rec = tp / float(npos)  # NOTE(review): npos may be 0 when the class has no non-difficult GT
+    # avoid divide by zero in case the first detection matches a difficult
+    # ground truth
+    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+    ap = voc_ap(rec, prec, use_07_metric)
+
+    return rec, prec, ap
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/rotated_coco_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/rotated_coco_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d5306c3a0601ed555c7bef20e0ac4ca64264442
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/rotated_coco_evaluation.py
@@ -0,0 +1,207 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import itertools
+import json
+import numpy as np
+import os
+import torch
+from annotator.oneformer.pycocotools.cocoeval import COCOeval, maskUtils
+
+from annotator.oneformer.detectron2.structures import BoxMode, RotatedBoxes, pairwise_iou_rotated
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .coco_evaluation import COCOEvaluator
+
+
+class RotatedCOCOeval(COCOeval):
+    """COCOeval whose per-image IoU step also accepts 5-value (rotated) boxes."""
+
+    @staticmethod
+    def is_rotated(box_list):
+        """Tell whether ``box_list`` (an ndarray or a list) holds 5-value boxes."""
+        if type(box_list) is np.ndarray:
+            return box_list.shape[1] == 5
+        if type(box_list) is not list or box_list == []:
+            # neither ndarray nor list, or an empty list: cannot decide the box_dim
+            return False
+        flags = [len(obj) == 5 and type(obj) in (list, np.ndarray) for obj in box_list]
+        return np.all(np.array(flags))
+
+    @staticmethod
+    def boxlist_to_tensor(boxlist, output_box_dim):
+        """
+        Turn ``boxlist`` (ndarray or list) into a tensor with ``output_box_dim`` columns.
+        The only conversion performed is 4-dim -> 5-dim, delegated to ``BoxMode.convert``
+        (XYWH_ABS -> XYWHA_ABS); any other width mismatch raises.
+        """
+        if type(boxlist) is np.ndarray:
+            tensor = torch.from_numpy(boxlist)
+        elif type(boxlist) is list:
+            if boxlist == []:
+                return torch.zeros((0, output_box_dim), dtype=torch.float32)
+            tensor = torch.FloatTensor(boxlist)
+        else:
+            raise Exception("Unrecognized boxlist type")
+
+        src_dim = tensor.shape[1]
+        if src_dim == output_box_dim:
+            return tensor
+        if src_dim == 4 and output_box_dim == 5:
+            return BoxMode.convert(tensor, BoxMode.XYWH_ABS, BoxMode.XYWHA_ABS)
+        message = "Unable to convert from {}-dim box to {}-dim box".format(
+            src_dim, output_box_dim
+        )
+        raise Exception(message)
+
+    def compute_iou_dt_gt(self, dt, gt, is_crowd):
+        """Pairwise IoU of detections vs. ground truths; rotated path if either is 5-dim."""
+        if not (self.is_rotated(dt) or self.is_rotated(gt)):
+            # This is the same as the classical COCO evaluation
+            return maskUtils.iou(dt, gt, is_crowd)
+
+        # TODO: take is_crowd into consideration
+        assert all(c == 0 for c in is_crowd)
+        rotated_dt = RotatedBoxes(self.boxlist_to_tensor(dt, output_box_dim=5))
+        rotated_gt = RotatedBoxes(self.boxlist_to_tensor(gt, output_box_dim=5))
+        return pairwise_iou_rotated(rotated_dt, rotated_gt)
+
+    def computeIoU(self, imgId, catId):
+        """IoUs between the top-scoring detections and the ground truths of one image."""
+        params = self.params
+        if params.useCats:
+            gt = self._gts[imgId, catId]
+            dt = self._dts[imgId, catId]
+        else:
+            gt = [ann for cId in params.catIds for ann in self._gts[imgId, cId]]
+            dt = [ann for cId in params.catIds for ann in self._dts[imgId, cId]]
+        if len(gt) == 0 and len(dt) == 0:
+            return []
+
+        # sort by descending score (mergesort), then keep at most maxDets[-1] detections
+        order = np.argsort([-det["score"] for det in dt], kind="mergesort")
+        dt = [dt[i] for i in order]
+        limit = params.maxDets[-1]
+        if len(dt) > limit:
+            dt = dt[0:limit]
+
+        assert params.iouType == "bbox", "unsupported iouType for iou computation"
+
+        gt_boxes = [ann["bbox"] for ann in gt]
+        dt_boxes = [det["bbox"] for det in dt]
+        iscrowd = [int(ann["iscrowd"]) for ann in gt]
+
+        # Note: this function is copied from cocoeval.py in cocoapi
+        # and the major difference is here.
+        return self.compute_iou_dt_gt(dt_boxes, gt_boxes, iscrowd)
+
+
+class RotatedCOCOEvaluator(COCOEvaluator):
+    """
+    COCO-style evaluator for object proposal / instance detection outputs whose boxes
+    may be rotated. Matching is done on IoU alone (via :class:`RotatedCOCOeval`);
+    differences in box angle are not scored separately.
+    """
+
+    def process(self, inputs, outputs):
+        """
+        Record one batch of model outputs in ``self._predictions``.
+
+        Args:
+            inputs: list of dicts, one per image, as fed to a COCO model (e.g.,
+                GeneralizedRCNN); only the "image_id" key is read here.
+            outputs: list of dicts, one per image; an "instances" entry is converted
+                to json-style dicts, a "proposals" entry is kept as-is (moved to CPU).
+        """
+        for image_input, image_output in zip(inputs, outputs):
+            image_id = image_input["image_id"]
+            record = {"image_id": image_id}
+            if "instances" in image_output:
+                cpu_instances = image_output["instances"].to(self._cpu_device)
+                record["instances"] = self.instances_to_json(cpu_instances, image_id)
+            if "proposals" in image_output:
+                record["proposals"] = image_output["proposals"].to(self._cpu_device)
+            self._predictions.append(record)
+
+    def instances_to_json(self, instances, img_id):
+        """Flatten ``instances`` into COCO-result dicts; 4-column boxes become XYWH_ABS."""
+        count = len(instances)
+        if count == 0:
+            return []
+
+        box_array = instances.pred_boxes.tensor.numpy()
+        if box_array.shape[1] == 4:
+            box_array = BoxMode.convert(box_array, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+        box_list = box_array.tolist()
+        score_list = instances.scores.tolist()
+        class_list = instances.pred_classes.tolist()
+
+        # one result dict per instance, built from the parallel lists above
+        return [
+            {
+                "image_id": img_id,
+                "category_id": class_list[k],
+                "bbox": box_list[k],
+                "score": score_list[k],
+            }
+            for k in range(count)
+        ]
+
+    def _eval_predictions(self, predictions, img_ids=None):  # img_ids: unused
+        """
+        Run the "bbox" evaluation over ``predictions`` and store the derived metrics in
+        ``self._results["bbox"]``. The raw results are dumped to json first when an
+        output directory is set; nothing is stored when evaluation is disabled.
+        """
+        logger = self._logger
+        logger.info("Preparing results for COCO format ...")
+        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
+
+        # unmap the category ids for COCO
+        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+            contiguous_to_dataset = {
+                contiguous: dataset
+                for dataset, contiguous in self._metadata.thing_dataset_id_to_contiguous_id.items()
+            }
+            for entry in coco_results:
+                entry["category_id"] = contiguous_to_dataset[entry["category_id"]]
+
+        if self._output_dir:
+            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
+            logger.info("Saving results to {}".format(file_path))
+            with PathManager.open(file_path, "w") as f:
+                f.write(json.dumps(coco_results))
+                f.flush()
+
+        if not self._do_evaluation:
+            logger.info("Annotations are not available for evaluation.")
+            return
+
+        logger.info("Evaluating predictions ...")
+
+        tasks = self._tasks
+        assert tasks is None or set(tasks) == {
+            "bbox"
+        }, "[RotatedCOCOEvaluator] Only bbox evaluation is supported"
+        if len(coco_results) > 0:
+            coco_eval = self._evaluate_predictions_on_coco(self._coco_api, coco_results)
+        else:
+            coco_eval = None  # cocoapi does not handle empty results very well
+
+        names = self._metadata.get("thing_classes")
+        self._results["bbox"] = self._derive_coco_results(coco_eval, "bbox", class_names=names)
+
+    def _evaluate_predictions_on_coco(self, coco_gt, coco_results):
+        """
+        Score non-empty ``coco_results`` against ``coco_gt`` with :class:`RotatedCOCOeval`
+        and return the evaluator after evaluate / accumulate / summarize have run.
+        """
+        assert len(coco_results) > 0
+
+        detections = coco_gt.loadRes(coco_results)
+
+        # Only bbox is supported for now
+        evaluator = RotatedCOCOeval(coco_gt, detections, iouType="bbox")
+
+        evaluator.evaluate()
+        evaluator.accumulate()
+        evaluator.summarize()
+        return evaluator
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/sem_seg_evaluation.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/sem_seg_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2f3f5a659bc270d313efb053908d9b1e942f44
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/sem_seg_evaluation.py
@@ -0,0 +1,265 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import itertools
+import json
+import logging
+import numpy as np
+import os
+from collections import OrderedDict
+from typing import Optional, Union
+import annotator.oneformer.pycocotools.mask as mask_util
+import torch
+from PIL import Image
+
+from annotator.oneformer.detectron2.data import DatasetCatalog, MetadataCatalog
+from annotator.oneformer.detectron2.utils.comm import all_gather, is_main_process, synchronize
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .evaluator import DatasetEvaluator
+
+try:
+    import cv2  # noqa
+
+    _CV2_IMPORTED = True
+except ImportError:
+    _CV2_IMPORTED = False  # OpenCV is an optional dependency at the moment
+
+
+def load_image_into_numpy_array(
+    filename: str, copy: bool = False, dtype: Optional[Union[np.dtype, str]] = None
+) -> np.ndarray:
+    """Read image ``filename`` via PathManager into an ndarray (``copy=True`` forces a copy)."""
+    with PathManager.open(filename, "rb") as f:
+        # copy=False must mean "copy only if needed": NumPy >= 2 raises for np.array(copy=False)
+        image = Image.open(f)
+        return np.array(image, dtype=dtype) if copy else np.asarray(image, dtype=dtype)
+
+
+class SemSegEvaluator(DatasetEvaluator):
+    """
+    Evaluate semantic segmentation metrics, including Boundary IoU when it can be computed.
+    """
+
+    def __init__(
+        self,
+        dataset_name,
+        distributed=True,
+        output_dir=None,
+        *,
+        sem_seg_loading_fn=load_image_into_numpy_array,
+        num_classes=None,
+        ignore_label=None,
+    ):
+        """
+        Args:
+            dataset_name (str): name of the dataset to be evaluated.
+            distributed (bool): if True, will collect results from all ranks for evaluation.
+                Otherwise, will evaluate the results in the current process.
+            output_dir (str): an output directory to dump results.
+            sem_seg_loading_fn: function to read sem seg file and load into numpy array.
+                Default provided, but projects can customize.
+            num_classes, ignore_label: deprecated arguments; the values come from metadata.
+        """
+        self._logger = logging.getLogger(__name__)
+        if num_classes is not None:
+            self._logger.warning(
+                "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata."
+            )
+        if ignore_label is not None:
+            self._logger.warning(
+                "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata."
+            )
+        self._dataset_name = dataset_name
+        self._distributed = distributed
+        self._output_dir = output_dir
+
+        self._cpu_device = torch.device("cpu")
+
+        self.input_file_to_gt_file = {
+            dataset_record["file_name"]: dataset_record["sem_seg_file_name"]
+            for dataset_record in DatasetCatalog.get(dataset_name)
+        }
+
+        meta = MetadataCatalog.get(dataset_name)
+        # Dict that maps contiguous training ids to COCO category ids
+        try:
+            c2d = meta.stuff_dataset_id_to_contiguous_id
+            self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()}
+        except AttributeError:
+            self._contiguous_id_to_dataset_id = None
+        self._class_names = meta.stuff_classes
+        self.sem_seg_loading_fn = sem_seg_loading_fn
+        self._num_classes = len(meta.stuff_classes)
+        if num_classes is not None:
+            assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}"
+        self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label
+
+        # This is because cv2.erode did not work for int datatype. Only works for uint8.
+        self._compute_boundary_iou = True
+        if not _CV2_IMPORTED:
+            self._compute_boundary_iou = False
+            self._logger.warning(
+                """Boundary IoU calculation requires OpenCV. B-IoU metrics are
+                not going to be computed because OpenCV is not available to import."""
+            )
+        if self._num_classes >= np.iinfo(np.uint8).max:
+            self._compute_boundary_iou = False
+            self._logger.warning(
+                f"""SemSegEvaluator(num_classes) is more than supported value for Boundary IoU calculation!
+                B-IoU metrics are not going to be computed. Max allowed value (exclusive)
+                for num_classes for calculating Boundary IoU is {np.iinfo(np.uint8).max}.
+                The number of classes of dataset {self._dataset_name} is {self._num_classes}"""
+            )
+
+    def reset(self):
+        """Zero both confusion matrices (extra index: ignored gt pixels) and drop predictions."""
+        size = self._num_classes + 1
+        self._conf_matrix = np.zeros((size, size), dtype=np.int64)
+        self._b_conf_matrix = np.zeros((size, size), dtype=np.int64)
+        self._predictions = []
+
+    def process(self, inputs, outputs):
+        """
+        Args:
+            inputs: the inputs to a model.
+                It is a list of dicts. Each dict corresponds to an image and
+                contains keys like "height", "width", "file_name".
+            outputs: the outputs of a model. It is a list of dicts with key "sem_seg"
+                holding a tensor; its argmax over dim 0 is taken as the predicted
+                label map of the image.
+        """
+        for input, output in zip(inputs, outputs):
+            output = output["sem_seg"].argmax(dim=0).to(self._cpu_device)
+            # builtin int: the np.int alias was removed in NumPy 1.24
+            pred = np.array(output, dtype=int)
+            gt_filename = self.input_file_to_gt_file[input["file_name"]]
+            gt = self.sem_seg_loading_fn(gt_filename, dtype=int)
+            gt[gt == self._ignore_label] = self._num_classes
+
+            # flat index (num_classes + 1) * pred + gt selects cell [pred, gt]
+            self._conf_matrix += np.bincount(
+                (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
+                minlength=self._conf_matrix.size,
+            ).reshape(self._conf_matrix.shape)
+
+            if self._compute_boundary_iou:
+                b_gt = self._mask_to_boundary(gt.astype(np.uint8))
+                b_pred = self._mask_to_boundary(pred.astype(np.uint8))
+                self._b_conf_matrix += np.bincount(
+                    (self._num_classes + 1) * b_pred.reshape(-1) + b_gt.reshape(-1),
+                    minlength=self._conf_matrix.size,
+                ).reshape(self._conf_matrix.shape)
+
+            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
+
+    def evaluate(self):
+        """
+        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
+        mIoU, fwIoU, mACC and pACC, plus per-class values and, when enabled, Boundary IoU.
+
+        Returns:
+            OrderedDict: ``{"sem_seg": {metric name: value scaled by 100}}``, or None on
+            non-main processes of a distributed run.
+        """
+        if self._distributed:
+            synchronize()
+            conf_matrix_list = all_gather(self._conf_matrix)
+            b_conf_matrix_list = all_gather(self._b_conf_matrix)
+            self._predictions = all_gather(self._predictions)
+            self._predictions = list(itertools.chain(*self._predictions))
+            if not is_main_process():
+                return
+
+            self._conf_matrix = np.zeros_like(self._conf_matrix)
+            for conf_matrix in conf_matrix_list:
+                self._conf_matrix += conf_matrix
+
+            self._b_conf_matrix = np.zeros_like(self._b_conf_matrix)
+            for b_conf_matrix in b_conf_matrix_list:
+                self._b_conf_matrix += b_conf_matrix
+
+        if self._output_dir:
+            PathManager.mkdirs(self._output_dir)
+            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
+            with PathManager.open(file_path, "w") as f:
+                f.write(json.dumps(self._predictions))
+
+        acc = np.full(self._num_classes, np.nan, dtype=float)  # np.float is gone in NumPy 1.24+
+        iou = np.full(self._num_classes, np.nan, dtype=float)
+        tp = self._conf_matrix.diagonal()[:-1].astype(float)
+        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(float)
+        class_weights = pos_gt / np.sum(pos_gt)
+        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(float)
+        acc_valid = pos_gt > 0
+        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
+        union = pos_gt + pos_pred - tp
+        iou_valid = np.logical_and(acc_valid, union > 0)
+        iou[iou_valid] = tp[iou_valid] / union[iou_valid]
+        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
+        miou = np.sum(iou[iou_valid]) / np.sum(iou_valid)
+        fiou = np.sum(iou[iou_valid] * class_weights[iou_valid])
+        pacc = np.sum(tp) / np.sum(pos_gt)
+
+        if self._compute_boundary_iou:
+            b_iou = np.full(self._num_classes, np.nan, dtype=float)
+            b_tp = self._b_conf_matrix.diagonal()[:-1].astype(float)
+            b_pos_gt = np.sum(self._b_conf_matrix[:-1, :-1], axis=0).astype(float)
+            b_pos_pred = np.sum(self._b_conf_matrix[:-1, :-1], axis=1).astype(float)
+            b_union = b_pos_gt + b_pos_pred - b_tp
+            b_iou_valid = b_union > 0
+            b_iou[b_iou_valid] = b_tp[b_iou_valid] / b_union[b_iou_valid]
+
+        res = {}
+        res["mIoU"] = 100 * miou
+        res["fwIoU"] = 100 * fiou
+        for i, name in enumerate(self._class_names):
+            res[f"IoU-{name}"] = 100 * iou[i]
+            if self._compute_boundary_iou:
+                res[f"BoundaryIoU-{name}"] = 100 * b_iou[i]
+                res[f"min(IoU, B-Iou)-{name}"] = 100 * min(iou[i], b_iou[i])
+        res["mACC"] = 100 * macc
+        res["pACC"] = 100 * pacc
+        for i, name in enumerate(self._class_names):
+            res[f"ACC-{name}"] = 100 * acc[i]
+
+        if self._output_dir:
+            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
+            with PathManager.open(file_path, "wb") as f:
+                torch.save(res, f)
+        results = OrderedDict({"sem_seg": res})
+        self._logger.info(results)
+        return results
+
+    def encode_json_sem_seg(self, sem_seg, input_file_name):
+        """
+        Convert semantic segmentation to COCO stuff format with segments encoded as RLEs:
+        one dict per label present in ``sem_seg``. See http://cocodataset.org/#format-results
+        """
+        json_list = []
+        for label in np.unique(sem_seg):
+            if self._contiguous_id_to_dataset_id is not None:
+                assert (
+                    label in self._contiguous_id_to_dataset_id
+                ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name)
+                dataset_id = self._contiguous_id_to_dataset_id[label]
+            else:
+                dataset_id = int(label)
+            mask = (sem_seg == label).astype(np.uint8)
+            mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0]
+            mask_rle["counts"] = mask_rle["counts"].decode("utf-8")
+            json_list.append(
+                {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle}
+            )
+        return json_list
+
+    def _mask_to_boundary(self, mask: np.ndarray, dilation_ratio=0.02):
+        """Return ``mask`` minus its erosion; erosion depth is ``dilation_ratio`` x diagonal."""
+        assert mask.ndim == 2, "mask_to_boundary expects a 2-dimensional image"
+        h, w = mask.shape
+        diag_len = np.sqrt(h**2 + w**2)
+        dilation = max(1, int(round(dilation_ratio * diag_len)))
+        kernel = np.ones((3, 3), dtype=np.uint8)
+
+        # pad with a zero border before eroding, then crop the border off again
+        padded_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
+        eroded_mask = cv2.erode(padded_mask, kernel, iterations=dilation)[1:-1, 1:-1]
+        return mask - eroded_mask
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/testing.py b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5ae625bb0593fc20739dd3ea549157e4df4f3d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/evaluation/testing.py
@@ -0,0 +1,85 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+import pprint
+import sys
+from collections.abc import Mapping
+
+
+def print_csv_format(results):
+    """
+    Log the main metrics of every task as comma-separated "copypaste:" lines, in a
+    format similar to Detectron, so they are easy to copypaste into a spreadsheet.
+
+    Args:
+        results (OrderedDict[dict]): task_name -> {metric -> score}
+            unordered dict can also be printed, but in arbitrary order
+    """
+    assert isinstance(results, Mapping) or not len(results), results
+    logger = logging.getLogger(__name__)
+    for task, res in results.items():
+        if not isinstance(res, Mapping):
+            logger.info(f"copypaste: {task}={res}")
+            continue
+        # Don't print "AP-category" metrics since they are usually not tracked.
+        names = [name for name in res if "-" not in name]
+        logger.info("copypaste: Task: {}".format(task))
+        logger.info("copypaste: " + ",".join(names))
+        logger.info("copypaste: " + ",".join("{0:.4f}".format(res[name]) for name in names))
+
+
+def verify_results(cfg, results):
+    """
+    Compare ``results`` with the (task, metric, expected, tolerance) tuples listed in
+    ``cfg.TEST.EXPECTED_RESULTS``. A mismatch is logged and ends the process with
+    exit status 1.
+
+    Args:
+        cfg: config providing ``TEST.EXPECTED_RESULTS``
+        results (OrderedDict[dict]): task_name -> {metric -> score}
+
+    Returns:
+        bool: True when nothing is expected or every metric is finite and within tolerance
+    """
+    expected_results = cfg.TEST.EXPECTED_RESULTS
+    if not len(expected_results):
+        return True
+
+    ok = True
+    for task, metric, expected, tolerance in expected_results:
+        actual = results[task].get(metric, None)
+        # missing, non-finite or out-of-tolerance values all count as failures
+        if actual is None or not np.isfinite(actual) or abs(actual - expected) > tolerance:
+            ok = False
+
+    logger = logging.getLogger(__name__)
+    if ok:
+        logger.info("Results verification passed.")
+        return ok
+
+    logger.error("Result verification failed!")
+    logger.error("Expected Results: " + str(expected_results))
+    logger.error("Actual Results: " + pprint.pformat(results))
+
+    sys.exit(1)
+    return ok
+
+
+def flatten_results_dict(results):
+    """
+    Collapse a nested mapping of scalars into a single-level dict whose keys are
+    the "/"-joined key paths: results[k1][k2][k3] = v yields the entry
+    {"k1/k2/k3": v}.
+
+    Args:
+        results (dict):
+    """
+    flat = {}
+    for key, value in results.items():
+        if not isinstance(value, Mapping):
+            flat[key] = value
+            continue
+        # recurse first, then prefix every flattened child key with this key
+        children = flatten_results_dict(value)
+        flat.update({key + "/" + sub_key: sub_value for sub_key, sub_value in children.items()})
+    return flat
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/README.md b/sd-webui-controlnet/annotator/oneformer/detectron2/export/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c86ff62516f4e8e4b1a6c1f33f11192933cf3861
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/README.md
@@ -0,0 +1,15 @@
+
+This directory contains code to prepare a detectron2 model for deployment.
+Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format.
+
+Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage.
+
+
+### Acknowledgements
+
+Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion tools.
+
+Thanks to Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who
+helped export Detectron2 models to TorchScript.
+
+Thanks to ONNX Converter team at Microsoft who helped export Detectron2 models to ONNX.
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a58758f64aae6071fa688be4400622ce6036efa
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+
+import warnings
+
+from .flatten import TracingAdapter
+from .torchscript import dump_torchscript_IR, scripting_with_instances
+
+try:
+    from caffe2.proto import caffe2_pb2 as _tmp
+    from caffe2.python import core
+
+    # caffe2 is optional: the Caffe2 export API below is only imported when it is present
+except ImportError:
+    pass
+else:
+    from .api import *  # api.__all__ lists Caffe2Model and Caffe2Tracer
+
+
+# TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported
+STABLE_ONNX_OPSET_VERSION = 11
+
+
+def add_export_config(cfg):
+    """Deprecated no-op kept for backward compatibility: warns and returns ``cfg`` as is."""
+    message = "add_export_config has been deprecated and behaves as no-op function."
+    warnings.warn(message, DeprecationWarning)
+    return cfg
+
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/api.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf1a27a4806ca83d97f5cd8c27726ec29f4e7e50
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/api.py
@@ -0,0 +1,230 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import os
+import torch
+from caffe2.proto import caffe2_pb2
+from torch import nn
+
+from annotator.oneformer.detectron2.config import CfgNode
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .caffe2_inference import ProtobufDetectionModel
+from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
+from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
+
+__all__ = [
+ "Caffe2Model",
+ "Caffe2Tracer",
+]
+
+
+class Caffe2Tracer:
+    """
+    Make a detectron2 model traceable with Caffe2 operators.
+    This class creates a traceable version of a detectron2 model which:
+
+    1. Rewrites parts of the model using ops in Caffe2. Note that some ops do
+       not have GPU implementation in Caffe2.
+    2. Removes post-processing and only produces raw layer outputs
+
+    After making a traceable model, the class provides methods to export such a
+    model to different deployment formats.
+    Exported graphs produced by this class take two input tensors:
+
+    1. (1, C, H, W) float "data" which is an image (usually in [0, 255]).
+       (H, W) often has to be padded to multiple of 32 (depends on the model
+       architecture).
+    2. 1x3 float "im_info", each row of which is (height, width, 1.0).
+       Height and width are true image shapes before padding.
+
+    The class currently only supports models using builtin meta architectures.
+    Batch inference is not supported, and contributions are welcome.
+    """
+
+    def __init__(self, cfg: CfgNode, model: nn.Module, inputs):
+        """
+        Args:
+            cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model.
+            model (nn.Module): An original pytorch model. Must be among a few official models
+                in detectron2 that can be converted to become caffe2-compatible automatically.
+                Weights have to be already loaded to this model.
+            inputs: sample inputs that the given model takes for inference.
+                Will be used to trace the model. For most models, random inputs with
+                no detected objects will not work as they lead to wrong traces.
+        """
+        assert isinstance(cfg, CfgNode), cfg
+        assert isinstance(model, torch.nn.Module), type(model)
+
+        # TODO make it support custom models, by passing in c2 model directly
+        C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE]
+        self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model))  # wraps a deep copy of model
+        self.inputs = inputs
+        self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs)
+
+    def export_caffe2(self):
+        """
+        Export the model to Caffe2's protobuf format.
+        The returned object can be saved with its :meth:`.save_protobuf()` method.
+        The result can be loaded and executed using Caffe2 runtime.
+
+        Returns:
+            :class:`Caffe2Model`
+        """
+        from .caffe2_export import export_caffe2_detection_model
+
+        predict_net, init_net = export_caffe2_detection_model(
+            self.traceable_model, self.traceable_inputs
+        )
+        return Caffe2Model(predict_net, init_net)
+
+    def export_onnx(self):
+        """
+        Export the model to ONNX format.
+        Note that the exported model contains custom ops only available in caffe2, therefore it
+        cannot be directly executed by other runtimes (such as onnxruntime or TensorRT).
+        Post-processing or transformation passes may be applied on the model to accommodate
+        different runtimes, but we currently do not provide support for them.
+
+        Returns:
+            onnx.ModelProto: an onnx model.
+        """
+        from .caffe2_export import export_onnx_model as export_onnx_model_impl
+
+        return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,))
+
+    def export_torchscript(self):
+        """
+        Export the model to a ``torch.jit.TracedModule`` by tracing.
+        The returned object can be saved to a file by ``.save()``.
+
+        Returns:
+            torch.jit.TracedModule: a torch TracedModule
+        """
+        logger = logging.getLogger(__name__)
+        logger.info("Tracing the model with torch.jit.trace ...")
+        with torch.no_grad():
+            return torch.jit.trace(self.traceable_model, (self.traceable_inputs,))
+
+
+class Caffe2Model(nn.Module):
+    """
+    A wrapper around the traced model in Caffe2's protobuf format.
+    The exported graph has different inputs/outputs from the original Pytorch
+    model, as explained in :class:`Caffe2Tracer`. This class wraps around the
+    exported graph to simulate the same interface as the original Pytorch model.
+    It also provides functions to save/load models in Caffe2's format.
+
+    Examples:
+    ::
+        c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2()
+        inputs = [{"image": img_tensor_CHW}]
+        outputs = c2_model(inputs)
+        orig_outputs = torch_model(inputs)
+    """
+
+    def __init__(self, predict_net, init_net):
+        super().__init__()
+        self.eval()  # always in eval mode
+        self._predict_net = predict_net
+        self._init_net = init_net
+        self._predictor = None  # built lazily on the first __call__
+
+    __init__.__HIDE_SPHINX_DOC__ = True  # presumably read by the docs build - TODO confirm
+
+    @property
+    def predict_net(self):
+        """
+        caffe2.core.Net: the underlying caffe2 predict net
+        """
+        return self._predict_net
+
+    @property
+    def init_net(self):
+        """
+        caffe2.core.Net: the underlying caffe2 init net
+        """
+        return self._init_net
+
+    def save_protobuf(self, output_dir):
+        """
+        Save the model as caffe2's protobuf format.
+        It saves the following files:
+
+            * "model.pb": definition of the graph. Can be visualized with
+              tools like `netron <https://github.com/lutzroeder/netron>`_.
+            * "model_init.pb": model parameters
+            * "model.pbtxt": human-readable definition of the graph. Not
+              needed for deployment.
+
+        Args:
+            output_dir (str): the output directory to save protobuf files.
+        """
+        logger = logging.getLogger(__name__)
+        logger.info("Saving model to {} ...".format(output_dir))
+        if not PathManager.exists(output_dir):
+            PathManager.mkdirs(output_dir)
+
+        with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f:
+            f.write(self._predict_net.SerializeToString())
+        with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f:
+            f.write(str(self._predict_net))
+        with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f:
+            f.write(self._init_net.SerializeToString())
+
+    def save_graph(self, output_file, inputs=None):
+        """
+        Save the graph as SVG format.
+
+        Args:
+            output_file (str): a SVG file
+            inputs: optional inputs given to the model.
+                If given, the inputs will be used to run the graph to record
+                shape of every tensor. The shape information will be
+                saved together with the graph.
+        """
+        from .caffe2_export import run_and_save_graph
+
+        if inputs is None:
+            save_graph(self._predict_net, output_file, op_only=False)
+        else:
+            size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0)
+            device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii")
+            inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device)
+            inputs = [x.cpu().numpy() for x in inputs]
+            run_and_save_graph(self._predict_net, self._init_net, inputs, output_file)
+
+    @staticmethod
+    def load_protobuf(dir):
+        """
+        Args:
+            dir (str): a directory used to save Caffe2Model with
+                :meth:`save_protobuf`.
+                The files "model.pb" and "model_init.pb" are needed.
+
+        Returns:
+            Caffe2Model: the caffe2 model loaded from this directory.
+        """
+        predict_net = caffe2_pb2.NetDef()
+        with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f:
+            predict_net.ParseFromString(f.read())
+
+        init_net = caffe2_pb2.NetDef()
+        with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f:
+            init_net.ParseFromString(f.read())
+
+        return Caffe2Model(predict_net, init_net)
+
+    def __call__(self, inputs):
+        """
+        An interface that wraps around a Caffe2 model and mimics detectron2's models'
+        input/output format. See details about the format at :doc:`/tutorials/models`.
+        This is used to compare the outputs of caffe2 model with its original torch model.
+
+        Due to the extra conversion between Pytorch/Caffe2, this method is not meant for
+        benchmark. Because of the conversion, this method also has dependency
+        on detectron2 in order to convert to detectron2's output format.
+        """
+        if self._predictor is None:
+            self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net)
+        return self._predictor(inputs)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/c10.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/c10.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde3fb71189e6f1061e83b878bfdd16add7d8350
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/c10.py
@@ -0,0 +1,557 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import math
+from typing import Dict
+import torch
+import torch.nn.functional as F
+
+from annotator.oneformer.detectron2.layers import ShapeSpec, cat
+from annotator.oneformer.detectron2.layers.roi_align_rotated import ROIAlignRotated
+from annotator.oneformer.detectron2.modeling import poolers
+from annotator.oneformer.detectron2.modeling.proposal_generator import rpn
+from annotator.oneformer.detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference
+from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes
+
+from .shared import alias, to_device
+
+
+"""
+This file contains caffe2-compatible implementation of several detectron2 components.
+"""
+
+
+class Caffe2Boxes(Boxes):
+    """
+    Boxes of a whole minibatch packed into one tensor. Each row is a 5d vector
+    (batch index + 4 coordinates) for Boxes or a 6d vector (batch index +
+    5 coordinates) for RotatedBoxes; plain 4-column rows are accepted as well.
+    """
+
+    def __init__(self, tensor):
+        assert isinstance(tensor, torch.Tensor)
+        is_2d_box_tensor = tensor.dim() == 2 and tensor.size(-1) in (4, 5, 6)
+        assert is_2d_box_tensor, tensor.size()
+        # TODO: make tensor immutable when dim is Nx5 for Boxes, Nx6 for RotatedBoxes?
+        self.tensor = tensor
+
+
+# TODO clean up this class, maybe just extend Instances
+class InstancesList(object):
+    """
+    Tensor representation of a list of Instances object for a batch of images.
+
+    When dealing with a batch of images with Caffe2 ops, a list of bboxes
+    (instances) are usually represented by single Tensor with size
+    (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is
+    for providing common functions to convert between these two representations.
+    """
+
+    def __init__(self, im_info, indices, extra_fields=None):
+        # [N, 3] -> (H, W, Scale)
+        self.im_info = im_info
+        # [N,] -> index of the batch image to which each instance belongs
+        self.indices = indices
+        # [N, ...] per-instance fields keyed by name
+        self.batch_extra_fields = extra_fields or {}
+
+        self.image_size = self.im_info
+
+    def get_fields(self):
+        """
+        Like `get_fields` in the Instances object, but every field is returned
+        in the batched representation stored in this object.
+
+        Returns:
+            dict: a new dict mapping each field name to its stored value.
+        """
+        return dict(self.batch_extra_fields)
+
+    def has(self, name):
+        """Return True if a field called `name` has been set."""
+        return name in self.batch_extra_fields
+
+    def set(self, name, value):
+        """Store `value` as field `name`; once a field exists, lengths must equal len(self)."""
+        # len(tensor) is a bad practice that generates ONNX constants during tracing.
+        # Although not a problem for the `assert` statement below, torch ONNX exporter
+        # still raises a misleading warning as it does not know this call comes from `assert`
+        if isinstance(value, Boxes):
+            data_len = value.tensor.shape[0]
+        elif isinstance(value, torch.Tensor):
+            data_len = value.shape[0]
+        else:
+            data_len = len(value)
+        if len(self.batch_extra_fields):
+            msg = "Adding a field of length {} to a Instances of length {}"
+            assert len(self) == data_len, msg.format(data_len, len(self))
+        self.batch_extra_fields[name] = value
+
+    def __getattr__(self, name):
+        # `batch_extra_fields` itself can be missing (e.g. on an instance created by
+        # copy/pickle before its state is restored); looking it up through `self`
+        # would re-enter __getattr__ forever, so fail with AttributeError instead.
+        if name == "batch_extra_fields" or name not in self.batch_extra_fields:
+            raise AttributeError("Cannot find field '{}' in the given Instances!".format(name))
+        return self.batch_extra_fields[name]
+
+    def __len__(self):
+        return len(self.indices)
+
+    def flatten(self):
+        ret = []
+        for _, v in self.batch_extra_fields.items():
+            if isinstance(v, (Boxes, Keypoints)):
+                ret.append(v.tensor)  # unwrap structures to their raw tensor
+            else:
+                ret.append(v)
+        return ret
+
+    @staticmethod
+    def to_d2_instances_list(instances_list):
+        """
+        Convert InstancesList to List[Instances]. The input `instances_list` can
+        also be a List[Instances], in this case this method is a no-op.
+        """
+        if not isinstance(instances_list, InstancesList):
+            assert all(isinstance(x, Instances) for x in instances_list)
+            return instances_list
+
+        ret = []
+        for i, info in enumerate(instances_list.im_info):
+            instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())]))
+
+            ids = instances_list.indices == i  # mask selecting the rows of image `i`
+            for k, v in instances_list.batch_extra_fields.items():
+                if isinstance(v, torch.Tensor):
+                    instances.set(k, v[ids])
+                    continue
+                elif isinstance(v, Boxes):
+                    instances.set(k, v[ids, -4:])  # keep only the last 4 columns
+                    continue
+                # Any other field is stored as a `(target_type, tensor)` pair.
+                target_type, tensor_source = v
+                assert isinstance(tensor_source, torch.Tensor)
+                assert tensor_source.shape[0] == instances_list.indices.shape[0]
+                tensor_source = tensor_source[ids]
+
+                if issubclass(target_type, Boxes):
+                    instances.set(k, Boxes(tensor_source[:, -4:]))
+                elif issubclass(target_type, Keypoints):
+                    instances.set(k, Keypoints(tensor_source))
+                elif issubclass(target_type, torch.Tensor):
+                    instances.set(k, tensor_source)
+                else:
+                    raise ValueError("Can't handle target type: {}".format(target_type))
+
+            ret.append(instances)
+        return ret
+
+
+class Caffe2Compatible(object):
+    """
+    Mixin that marks a model as one that can be traced and deployed with caffe2.
+    """
+
+    def _get_tensor_mode(self):
+        return self._tensor_mode
+
+    def _set_tensor_mode(self, v):
+        self._tensor_mode = v
+
+    # If true, the model expects C2-style tensor only inputs/outputs format.
+    # The value lives in `_tensor_mode`, which this class never initializes: reading
+    # `tensor_mode` before it has been assigned raises AttributeError.
+    tensor_mode = property(fget=_get_tensor_mode, fset=_set_tensor_mode)
+
+
+class Caffe2RPN(Caffe2Compatible, rpn.RPN):
+    """RPN variant that generates proposals with Caffe2 operators; inference only."""
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape)
+        # Only all-ones box regression weights (4 or 5 entries) pass this check.
+        assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple(
+            cfg.MODEL.RPN.BBOX_REG_WEIGHTS
+        ) == (1.0, 1.0, 1.0, 1.0, 1.0)
+        return ret
+
+    def _generate_proposals(
+        self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None
+    ):
+        """Generate proposals with Caffe2 ops; returns (c2_postprocess output, empty dict)."""
+        assert isinstance(images, ImageList)
+        if self.tensor_mode:
+            im_info = images.image_sizes
+        else:
+            im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to(
+                images.tensor.device
+            )
+        assert isinstance(im_info, torch.Tensor)
+
+        rpn_rois_list = []
+        rpn_roi_probs_list = []
+        for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip(
+            objectness_logits_pred,
+            anchor_deltas_pred,
+            [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()],
+            self.anchor_generator.strides,
+        ):
+            scores = scores.detach()
+            bbox_deltas = bbox_deltas.detach()
+
+            rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals(
+                scores,
+                bbox_deltas,
+                im_info,
+                cell_anchors_tensor,
+                spatial_scale=1.0 / feat_stride,
+                pre_nms_topN=self.pre_nms_topk[self.training],
+                post_nms_topN=self.post_nms_topk[self.training],
+                nms_thresh=self.nms_thresh,
+                min_size=self.min_box_size,
+                # correct_transform_coords=True,  # deprecated argument
+                angle_bound_on=True,  # Default
+                angle_bound_lo=-180,
+                angle_bound_hi=180,
+                clip_angle_thresh=1.0,  # Default
+                legacy_plus_one=False,
+            )
+            rpn_rois_list.append(rpn_rois)
+            rpn_roi_probs_list.append(rpn_roi_probs)
+
+        # For FPN in D2, in RPN all proposals from different levels are concatenated,
+        # ranked and picked by top post_nms_topk. Then in ROIPooler it calculates
+        # level_assignments and calls the RoIAlign from the corresponding level.
+        if len(objectness_logits_pred) == 1:
+            rpn_rois = rpn_rois_list[0]
+            rpn_roi_probs = rpn_roi_probs_list[0]
+        else:
+            assert len(rpn_rois_list) == len(rpn_roi_probs_list)
+            rpn_post_nms_topN = self.post_nms_topk[self.training]
+            device = rpn_rois_list[0].device
+            input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)]
+            # TODO remove this after confirming rpn_max_level/rpn_min_level
+            # is not needed in CollectRpnProposals.
+            feature_strides = list(self.anchor_generator.strides)
+            rpn_min_level = int(math.log2(feature_strides[0]))
+            rpn_max_level = int(math.log2(feature_strides[-1]))
+            assert (rpn_max_level - rpn_min_level + 1) == len(
+                rpn_rois_list
+            ), "CollectRpnProposals requires continuous levels"
+            rpn_rois = torch.ops._caffe2.CollectRpnProposals(
+                input_list,
+                # NOTE: in current implementation, rpn_max_level and rpn_min_level are not
+                # needed, only the subtraction of two matters and it can be inferred from
+                # the number of inputs. Keep them now for consistency.
+                rpn_max_level=2 + len(rpn_rois_list) - 1,
+                rpn_min_level=2,
+                rpn_post_nms_topN=rpn_post_nms_topN,
+            )
+            rpn_rois = to_device(rpn_rois, device)
+            rpn_roi_probs = []  # NOTE(review): per-level scores are not carried over here
+
+        proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode)
+        return proposals, {}
+
+    def forward(self, images, features, gt_instances=None):
+        """Inference-only forward: run the RPN head on `in_features` and generate proposals."""
+        assert not self.training
+        features = [features[f] for f in self.in_features]
+        objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
+        return self._generate_proposals(
+            images,
+            objectness_logits_pred,
+            anchor_deltas_pred,
+            gt_instances,
+        )
+
+    @staticmethod
+    def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode):
+        """Pack RoIs into an InstancesList; convert to a list of Instances unless `tensor_mode`."""
+        proposals = InstancesList(
+            im_info=im_info,
+            indices=rpn_rois[:, 0],
+            extra_fields={
+                "proposal_boxes": Caffe2Boxes(rpn_rois),
+                "objectness_logits": (torch.Tensor, rpn_roi_probs),
+            },
+        )
+        if not tensor_mode:
+            proposals = InstancesList.to_d2_instances_list(proposals)
+        else:
+            proposals = [proposals]
+        return proposals
+
+
+class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler):
+    """ROIPooler whose pooling is carried out by Caffe2 operators; inference only."""
+
+    @staticmethod
+    def c2_preprocess(box_lists):
+        """Return all boxes as the single tensor that is handed to the RoIAlign ops."""
+        assert all(isinstance(boxes, Boxes) for boxes in box_lists)
+        if not all(isinstance(boxes, Caffe2Boxes) for boxes in box_lists):
+            return poolers.convert_boxes_to_pooler_format(box_lists)
+        # Pure-tensor input: exactly one Caffe2Boxes is expected and used as-is.
+        assert len(box_lists) == 1
+        return box_lists[0].tensor
+
+    def forward(self, x, box_lists):
+        """
+        Pool features for the given boxes with Caffe2 RoIAlign ops.
+
+        With a single level pooler one RoIAlign call is made on `x[0]`. Otherwise the
+        boxes go through DistributeFpnProposals, every level is pooled separately and
+        BatchPermutation reorders the concatenated result with the restore indices.
+        """
+        assert not self.training
+
+        rois = self.c2_preprocess(box_lists)
+
+        if len(self.level_poolers) == 1:
+            only_pooler = self.level_poolers[0]
+            if isinstance(only_pooler, ROIAlignRotated):
+                roi_align_op, aligned = torch.ops._caffe2.RoIAlignRotated, True
+            else:
+                roi_align_op, aligned = torch.ops._caffe2.RoIAlign, only_pooler.aligned
+
+            feat = x[0]
+            feat = feat.dequantize() if feat.is_quantized else feat
+            return roi_align_op(
+                feat,
+                rois,
+                order="NCHW",
+                spatial_scale=float(only_pooler.spatial_scale),
+                pooled_h=int(self.output_size[0]),
+                pooled_w=int(self.output_size[1]),
+                sampling_ratio=int(only_pooler.sampling_ratio),
+                aligned=aligned,
+            )
+
+        # The boxes are handed to DistributeFpnProposals on CPU; outputs are moved back.
+        orig_device = rois.device
+        num_levels = self.max_level - self.min_level + 1
+        assert num_levels == 4, "Currently DistributeFpnProposals only support 4 levels"
+        distributed = torch.ops._caffe2.DistributeFpnProposals(
+            to_device(rois, "cpu"),
+            roi_canonical_scale=self.canonical_box_size,
+            roi_canonical_level=self.canonical_level,
+            roi_max_level=self.max_level,
+            roi_min_level=self.min_level,
+            legacy_plus_one=False,
+        )
+        distributed = [to_device(t, orig_device) for t in distributed]
+        # All outputs but the last are per-level RoIs; the last one is the restore index.
+        per_level_rois, restore_idx = distributed[:-1], distributed[-1]
+
+        pooled_per_level = []
+        for level_rois, level_feat, level_pooler in zip(per_level_rois, x, self.level_poolers):
+            if isinstance(level_pooler, ROIAlignRotated):
+                roi_align_op, aligned = torch.ops._caffe2.RoIAlignRotated, True
+            else:
+                roi_align_op, aligned = torch.ops._caffe2.RoIAlign, bool(level_pooler.aligned)
+
+            if level_feat.is_quantized:
+                level_feat = level_feat.dequantize()
+
+            pooled_per_level.append(
+                roi_align_op(
+                    level_feat,
+                    level_rois,
+                    order="NCHW",
+                    spatial_scale=float(level_pooler.spatial_scale),
+                    pooled_h=int(self.output_size[0]),
+                    pooled_w=int(self.output_size[1]),
+                    sampling_ratio=int(level_pooler.sampling_ratio),
+                    aligned=aligned,
+                )
+            )
+
+        shuffled = cat(pooled_per_level, dim=0)
+        assert shuffled.numel() > 0 and restore_idx.numel() > 0, (
+            "Caffe2 export requires tracing with a model checkpoint + input that can produce valid"
+            " detections. But no detections were obtained with the given checkpoint and input!"
+        )
+        return torch.ops._caffe2.BatchPermutation(shuffled, restore_idx)
+
+
+class Caffe2FastRCNNOutputsInference:
+    def __init__(self, tensor_mode):
+        self.tensor_mode = tensor_mode  # whether the output is caffe2 tensor mode
+
+    def __call__(self, box_predictor, predictions, proposals):
+        """
+        Equivalent to FastRCNNOutputLayers.inference, implemented with Caffe2 operators.
+        Returns `(results, kept_indices)`; both are lists (of length 1 in tensor mode).
+        """
+        num_classes = box_predictor.num_classes
+        score_thresh = box_predictor.test_score_thresh
+        nms_thresh = box_predictor.test_nms_thresh
+        topk_per_image = box_predictor.test_topk_per_image
+        is_rotated = len(box_predictor.box2box_transform.weights) == 5
+
+        if is_rotated:
+            box_dim = 5
+            assert box_predictor.box2box_transform.weights[4] == 1, (
+                "The weights for Rotated BBoxTransform in C2 have only 4 dimensions,"
+                + " thus enforcing the angle weight to be 1 for now"
+            )
+            box2box_transform_weights = box_predictor.box2box_transform.weights[:4]
+        else:
+            box_dim = 4
+            box2box_transform_weights = box_predictor.box2box_transform.weights
+
+        class_logits, box_regression = predictions
+        if num_classes + 1 == class_logits.shape[1]:
+            class_prob = F.softmax(class_logits, -1)
+        else:
+            assert num_classes == class_logits.shape[1]
+            class_prob = F.sigmoid(class_logits)
+            # BoxWithNMSLimit infers num_classes from the shape of class_prob, so append a
+            # zero background column (made from class_prob to share its device and dtype).
+            bg_placeholder = class_prob.new_zeros((class_prob.shape[0], 1))
+            class_prob = torch.cat((class_prob, bg_placeholder), dim=1)
+
+        assert box_regression.shape[1] % box_dim == 0
+        cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1
+        # Tensor-mode proposal boxes carry one extra leading column (the batch index).
+        input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1
+        proposal_boxes = proposals[0].proposal_boxes
+        if isinstance(proposal_boxes, Caffe2Boxes):
+            rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals])
+        elif isinstance(proposal_boxes, RotatedBoxes):
+            rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals])
+        elif isinstance(proposal_boxes, Boxes):
+            rois = Boxes.cat([p.proposal_boxes for p in proposals])
+        else:
+            raise NotImplementedError(
+                'Expected proposals[0].proposal_boxes to be type "Boxes", '
+                f"instead got {type(proposal_boxes)}"
+            )
+
+        device, dtype = rois.tensor.device, rois.tensor.dtype
+        if input_tensor_mode:
+            im_info = proposals[0].image_size
+            rois = rois.tensor
+        else:
+            im_info = torch.tensor([[p.image_size[0], p.image_size[1], 1.0] for p in proposals])
+            batch_ids = cat(
+                [
+                    torch.full((b, 1), i, dtype=dtype, device=device)
+                    for i, b in enumerate(len(p) for p in proposals)
+                ],
+                dim=0,
+            )
+            rois = torch.cat([batch_ids, rois.tensor], dim=1)
+
+        roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform(
+            to_device(rois, "cpu"),
+            to_device(box_regression, "cpu"),
+            to_device(im_info, "cpu"),
+            weights=box2box_transform_weights,
+            apply_scale=True,
+            rotated=is_rotated,
+            angle_bound_on=True,
+            angle_bound_lo=-180,
+            angle_bound_hi=180,
+            clip_angle_thresh=1.0,
+            legacy_plus_one=False,
+        )
+        roi_pred_bbox = to_device(roi_pred_bbox, device)
+        roi_batch_splits = to_device(roi_batch_splits, device)
+
+        nms_outputs = torch.ops._caffe2.BoxWithNMSLimit(
+            to_device(class_prob, "cpu"),
+            to_device(roi_pred_bbox, "cpu"),
+            to_device(roi_batch_splits, "cpu"),
+            score_thresh=float(score_thresh),
+            nms=float(nms_thresh),
+            detections_per_im=int(topk_per_image),
+            soft_nms_enabled=False,
+            soft_nms_method="linear",
+            soft_nms_sigma=0.5,
+            soft_nms_min_score_thres=0.001,
+            rotated=is_rotated,
+            cls_agnostic_bbox_reg=cls_agnostic_bbox_reg,
+            input_boxes_include_bg_cls=False,
+            output_classes_include_bg_cls=False,
+            legacy_plus_one=False,
+        )
+        roi_score_nms = to_device(nms_outputs[0], device)
+        roi_bbox_nms = to_device(nms_outputs[1], device)
+        roi_class_nms = to_device(nms_outputs[2], device)
+        roi_batch_splits_nms = to_device(nms_outputs[3], device)
+        roi_keeps_nms = to_device(nms_outputs[4], device)
+        roi_keeps_size_nms = to_device(nms_outputs[5], device)
+        if not self.tensor_mode:
+            roi_class_nms = roi_class_nms.to(torch.int64)
+
+        roi_batch_ids = cat(
+            [
+                torch.full((b, 1), i, dtype=dtype, device=device)
+                for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms)
+            ],
+            dim=0,
+        )
+
+        roi_class_nms = alias(roi_class_nms, "class_nms")
+        roi_score_nms = alias(roi_score_nms, "score_nms")
+        roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms")
+        roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms")
+        roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms")
+        roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms")
+
+        results = InstancesList(
+            im_info=im_info,
+            indices=roi_batch_ids[:, 0],
+            extra_fields={
+                "pred_boxes": Caffe2Boxes(roi_bbox_nms),
+                "scores": roi_score_nms,
+                "pred_classes": roi_class_nms,
+            },
+        )
+
+        if not self.tensor_mode:
+            results = InstancesList.to_d2_instances_list(results)
+            batch_splits = roi_batch_splits_nms.int().tolist()
+            kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits))
+        else:
+            results = [results]
+            kept_indices = [roi_keeps_nms]
+        return results, kept_indices
+
+
+class Caffe2MaskRCNNInference:
+    def __call__(self, pred_mask_logits, pred_instances):
+        """Equivalent to mask_head.mask_rcnn_inference; tensor mode sets "pred_masks"."""
+        tensor_mode = all(isinstance(inst, InstancesList) for inst in pred_instances)
+        if not tensor_mode:
+            mask_rcnn_inference(pred_mask_logits, pred_instances)
+            return
+        # Tensor mode: exactly one InstancesList is expected for the whole batch.
+        assert len(pred_instances) == 1
+        pred_instances[0].set("pred_masks", alias(pred_mask_logits.sigmoid(), "mask_fcn_probs"))
+
+
+class Caffe2KeypointRCNNInference:
+    def __init__(self, use_heatmap_max_keypoint):
+        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
+
+    def __call__(self, pred_keypoint_logits, pred_instances):
+        """Store keypoint predictions on a tensor-mode batch; always returns the input logits."""
+        # The raw heatmap is aliased as "kps_score"; HeatmapMaxKeypoint is applied on request.
+        kps = alias(pred_keypoint_logits, "kps_score")
+        if all(isinstance(inst, InstancesList) for inst in pred_instances):
+            assert len(pred_instances) == 1
+            batch = pred_instances[0]
+            if self.use_heatmap_max_keypoint:
+                orig_device = kps.device
+                kps = torch.ops._caffe2.HeatmapMaxKeypoint(
+                    to_device(kps, "cpu"),
+                    batch.pred_boxes.tensor,
+                    should_output_softmax=True,  # worth making it configurable?
+                )
+                kps = alias(to_device(kps, orig_device), "keypoints_out")
+            batch.set("pred_keypoints", kps)
+        return pred_keypoint_logits
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_export.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d609c27c7deb396352967dbcbc79b1e00f2a2de1
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_export.py
@@ -0,0 +1,203 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import copy
+import io
+import logging
+import numpy as np
+from typing import List
+import onnx
+import onnx.optimizer
+import torch
+from caffe2.proto import caffe2_pb2
+from caffe2.python import core
+from caffe2.python.onnx.backend import Caffe2Backend
+from tabulate import tabulate
+from termcolor import colored
+from torch.onnx import OperatorExportTypes
+
+from .shared import (
+ ScopedWS,
+ construct_init_net_from_params,
+ fuse_alias_placeholder,
+ fuse_copy_between_cpu_and_gpu,
+ get_params_from_init_net,
+ group_norm_replace_aten_with_caffe2,
+ infer_device_type,
+ remove_dead_end_ops,
+ remove_reshape_for_fc,
+ save_graph,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def export_onnx_model(model, inputs):
+    """
+    Trace a model and export it to ONNX, entirely in memory.
+
+    Args:
+        model (nn.Module): the module to trace; every submodule must already be
+            in eval mode, otherwise an AssertionError is raised.
+        inputs (tuple[args]): the model will be called by `model(*inputs)`
+
+    Returns:
+        an onnx model
+    """
+    assert isinstance(model, torch.nn.Module)
+
+    # ONNX export may change the training state of modules whose states are not
+    # consistent, so insist that the whole module tree is already in eval mode.
+    def _assert_eval(submodule):
+        assert not submodule.training
+
+    model.apply(_assert_eval)
+
+    # Serialize into an in-memory buffer, then parse the bytes back as an ONNX model.
+    with torch.no_grad(), io.BytesIO() as buffer:
+        torch.onnx.export(
+            model,
+            inputs,
+            buffer,
+            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+            # verbose=True,  # NOTE: uncomment this for debugging
+            # export_params=True,
+        )
+        onnx_model = onnx.load_from_string(buffer.getvalue())
+
+    return onnx_model
+
+
+def _op_stats(net_def):
+    """Return "<count>x <type>" lines for the net's ops, sorted by count (desc) then name."""
+    counts = {}
+    for op in net_def.op:
+        counts[op.type] = counts.get(op.type, 0) + 1
+    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
+    return "\n".join("{:>4}x {}".format(num, op_type) for op_type, num in ranked)
+
+
+def _assign_device_option(
+    predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor]
+):
+    """
+    ONNX exported network doesn't have concept of device, assign necessary
+    device option for each op in order to make it runnable on GPU runtime.
+    """
+    def _get_device_type(torch_tensor):
+        assert torch_tensor.device.type in ["cpu", "cuda"]
+        # CPU tensors report `device.index` as None; only CUDA tensors must be device 0.
+        assert torch_tensor.device.type == "cpu" or torch_tensor.device.index == 0
+        return torch_tensor.device.type
+
+    def _assign_op_device_option(net_proto, net_ssa, blob_device_types):
+        for op, ssa_i in zip(net_proto.op, net_ssa):
+            if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]:
+                op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
+            else:
+                devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]]
+                assert all(d == devices[0] for d in devices)
+                if devices[0] == "cuda":
+                    op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0))
+
+    # update ops in predict_net (in place), seeded with the devices of the input tensors
+    predict_net_input_device_types = {
+        (name, 0): _get_device_type(tensor)
+        for name, tensor in zip(predict_net.external_input, tensor_inputs)
+    }
+    predict_net_device_types = infer_device_type(
+        predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch"
+    )
+    predict_net_ssa, _ = core.get_ssa(predict_net)
+    _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types)
+
+    # update ops in init_net (in place); its outputs take the device predict_net expects
+    init_net_ssa, versions = core.get_ssa(init_net)
+    init_net_output_device_types = {
+        (name, versions[name]): predict_net_device_types[(name, 0)]
+        for name in init_net.external_output
+    }
+    init_net_device_types = infer_device_type(
+        init_net, known_status=init_net_output_device_types, device_name_style="pytorch"
+    )
+    _assign_op_device_option(init_net, init_net_ssa, init_net_device_types)
+
+
+def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]):
+    """
+    Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX.
+
+    Args:
+        model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py
+        tensor_inputs: a list of tensors that caffe2 model takes as input.
+
+    Returns:
+        tuple: `(predict_net, init_net)` produced by the export and optimization passes.
+    """
+    model = copy.deepcopy(model)  # the export works on a private copy of the model
+    assert isinstance(model, torch.nn.Module)
+    assert hasattr(model, "encode_additional_info")
+
+    # Step 1: trace the model to ONNX.
+    logger.info(
+        "Exporting a {} model via ONNX ...".format(type(model).__name__)
+        + " Some warnings from ONNX are expected and are usually not to worry about."
+    )
+    onnx_model = export_onnx_model(model, (tensor_inputs,))
+    # Step 2: convert the ONNX graph to Caffe2 protobufs and log the unoptimized op table.
+    init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model)
+    op_rows = [[op.type, op.input, op.output] for op in predict_net.op]
+    op_table = tabulate(op_rows, headers=["type", "input", "output"], tablefmt="pipe")
+    header = "ONNX export Done. Exported predict_net (before optimizations):\n"
+    logger.info(header + colored(op_table, "cyan"))
+
+    # Step 3: apply the protobuf-level optimization passes.
+    fuse_alias_placeholder(predict_net, init_net)
+    if not all(t.device.type == "cpu" for t in tensor_inputs):
+        fuse_copy_between_cpu_and_gpu(predict_net)
+        remove_dead_end_ops(init_net)
+        _assign_device_option(predict_net, init_net, tensor_inputs)
+    params, device_options = get_params_from_init_net(init_net)
+    predict_net, params = remove_reshape_for_fc(predict_net, params)
+    init_net = construct_init_net_from_params(params, device_options)
+    group_norm_replace_aten_with_caffe2(predict_net)
+
+    # Step 4: record the information needed to run the pb model in the Detectron2 system.
+    model.encode_additional_info(predict_net, init_net)
+    for net_name, net in (("predict_net", predict_net), ("init_net", init_net)):
+        logger.info("Operators used in {}: \n{}".format(net_name, _op_stats(net)))
+    return predict_net, init_net
+
+
+def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path):
+    """
+    Run the caffe2 model on given inputs, recording the shape and draw the graph.
+
+    predict_net/init_net: caffe2 model.
+    tensor_inputs: a list of tensors that caffe2 model takes as input.
+    graph_save_path: path for saving graph of exported model.
+    Returns: dict mapping every blob name in the temporary workspace to its fetched value.
+    """
+    logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path))
+    save_graph(predict_net, graph_save_path, op_only=False)
+
+    logger.info("Running ONNX exported model ...")
+    with ScopedWS("__ws_tmp__", True) as ws:
+        ws.RunNetOnce(init_net)
+        # External inputs that init_net did not create are fed from `tensor_inputs`, in order.
+        already_filled = set(ws.Blobs())
+        missing = [name for name in predict_net.external_input if name not in already_filled]
+        for blob_name, tensor in zip(missing, tensor_inputs):
+            ws.FeedBlob(blob_name, tensor)
+
+        try:
+            ws.RunNetOnce(predict_net)
+        except RuntimeError as e:
+            logger.warning("Encountered RuntimeError: \n{}".format(str(e)))
+
+        ws_blobs = {name: ws.FetchBlob(name) for name in ws.Blobs()}
+        blob_sizes = {b: v.shape for b, v in ws_blobs.items() if isinstance(v, np.ndarray)}
+
+        logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path))
+        save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes)
+
+        return ws_blobs
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_inference.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb886c0417285ed1d5ad85eb941fa1ac757cdab
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_inference.py
@@ -0,0 +1,161 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import logging
+import numpy as np
+from itertools import count
+import torch
+from caffe2.proto import caffe2_pb2
+from caffe2.python import core
+
+from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
+from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type
+
+logger = logging.getLogger(__name__)
+
+
+# ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ======
+class ProtobufModel(torch.nn.Module):
+    """
+    Wrapper of a caffe2's protobuf model.
+    It works just like nn.Module, but running caffe2 under the hood.
+    Input/Output are tuple[tensor] that match the caffe2 net's external_input/output.
+    """
+
+    _ids = count(0)
+
+    def __init__(self, predict_net, init_net):
+        logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...")
+        super().__init__()
+        assert isinstance(predict_net, caffe2_pb2.NetDef)
+        assert isinstance(init_net, caffe2_pb2.NetDef)
+        # create unique temporary workspace for each instance
+        self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids))
+        self.net = core.Net(predict_net)
+
+        logger.info("Running init_net once to fill the parameters ...")
+        with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws:
+            ws.RunNetOnce(init_net)
+            uninitialized_external_input = []  # external inputs that init_net did not create
+            for blob in self.net.Proto().external_input:
+                if blob not in ws.Blobs():
+                    uninitialized_external_input.append(blob)
+                    ws.CreateBlob(blob)
+            ws.CreateNet(self.net)
+
+        self._error_msgs = set()
+        self._input_blobs = uninitialized_external_input
+
+    def _infer_output_devices(self, inputs):
+        """
+        Returns:
+            list[str]: list of device for each external output
+        """
+
+        def _get_device_type(torch_tensor):
+            assert torch_tensor.device.type in ["cpu", "cuda"]
+            # CPU tensors report `device.index` as None; only CUDA tensors must be device 0.
+            assert torch_tensor.device.type == "cpu" or torch_tensor.device.index == 0
+            return torch_tensor.device.type
+
+        predict_net = self.net.Proto()
+        input_device_types = {
+            (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs)
+        }
+        device_type_map = infer_device_type(
+            predict_net, known_status=input_device_types, device_name_style="pytorch"
+        )
+        _, versions = core.get_ssa(predict_net)
+        versioned_outputs = [(name, versions[name]) for name in predict_net.external_output]
+        return [device_type_map[outp] for outp in versioned_outputs]
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (tuple[torch.Tensor])
+
+        Returns:
+            tuple[torch.Tensor]
+        """
+        assert len(inputs) == len(self._input_blobs), (
+            f"Length of inputs ({len(inputs)}) "
+            f"doesn't match the required input blobs: {self._input_blobs}"
+        )
+
+        with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws:
+            for b, tensor in zip(self._input_blobs, inputs):
+                ws.FeedBlob(b, tensor)
+
+            try:
+                ws.RunNet(self.net.Proto().name)
+            except RuntimeError as e:
+                if str(e) not in self._error_msgs:  # log each distinct message only once
+                    self._error_msgs.add(str(e))
+                    logger.warning("Encountered new RuntimeError: \n{}".format(str(e)))
+                logger.warning("Catch the error and use partial results.")
+
+            c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output]
+            # Remove outputs of current run, this is necessary in order to
+            # prevent fetching the result from previous run if the model fails
+            # in the middle.
+            for b in self.net.Proto().external_output:
+                # Needs to create uninitialized blob to make the net runnable.
+                # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b),
+                # but there's no such API.
+                ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).")
+
+        # Cast output to torch.Tensor on the desired device
+        output_devices = (
+            self._infer_output_devices(inputs)
+            if any(t.device.type != "cpu" for t in inputs)
+            else ["cpu" for _ in self.net.Proto().external_output]
+        )
+
+        outputs = []
+        for name, c2_output, device in zip(
+            self.net.Proto().external_output, c2_outputs, output_devices
+        ):
+            if not isinstance(c2_output, np.ndarray):
+                raise RuntimeError(
+                    "Invalid output for blob {}, received: {}".format(name, c2_output)
+                )
+            outputs.append(torch.tensor(c2_output).to(device=device))
+        return tuple(outputs)
+
+
+class ProtobufDetectionModel(torch.nn.Module):
+    """
+    A class works just like a pytorch meta arch in terms of inference, but running
+    caffe2 model under the hood.
+    """
+
+    def __init__(self, predict_net, init_net, *, convert_outputs=None):
+        """
+        Args:
+            predict_net, init_net (caffe2_pb2.NetDef): caffe2 nets
+            convert_outputs (callable): a function that converts caffe2
+                outputs to the same format of the original pytorch model.
+                By default, use the one defined in the caffe2 meta_arch.
+        """
+        super().__init__()
+        self.protobuf_model = ProtobufModel(predict_net, init_net)
+        self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0)
+        self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii")
+
+        if convert_outputs is not None:
+            self._convert_outputs = convert_outputs
+        else:
+            # Look up the export class by the architecture name stored in predict_net.
+            arch_name = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN")
+            arch_cls = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[arch_name.decode("ascii")]
+            self._convert_outputs = arch_cls.get_outputs_converter(predict_net, init_net)
+
+    def _convert_inputs(self, batched_inputs):
+        # currently all models convert inputs in the same way
+        return convert_batched_inputs_to_c2_format(batched_inputs, self.size_divisibility, self.device)
+
+    def forward(self, batched_inputs):
+        """Run the caffe2 model on `batched_inputs` and convert its named outputs back."""
+        c2_inputs = self._convert_inputs(batched_inputs)
+        raw_outputs = self.protobuf_model(c2_inputs)
+        named_outputs = dict(zip(self.protobuf_model.net.Proto().external_output, raw_outputs))
+        return self._convert_outputs(batched_inputs, c2_inputs, named_outputs)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_modeling.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0128e4672bc08eb2983d3d382614c6381baefd9
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_modeling.py
@@ -0,0 +1,419 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import functools
+import io
+import struct
+import types
+import torch
+
+from annotator.oneformer.detectron2.modeling import meta_arch
+from annotator.oneformer.detectron2.modeling.box_regression import Box2BoxTransform
+from annotator.oneformer.detectron2.modeling.roi_heads import keypoint_head
+from annotator.oneformer.detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes
+
+from .c10 import Caffe2Compatible
+from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn
+from .shared import (
+ alias,
+ check_set_pb_arg,
+ get_pb_arg_floats,
+ get_pb_arg_valf,
+ get_pb_arg_vali,
+ get_pb_arg_vals,
+ mock_torch_nn_functional_interpolate,
+)
+
+
+def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False):
+    """
+    A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor])
+    to detectron2's format (i.e. a list of Instances objects).
+    This only works when the model follows the Caffe2 detectron's naming convention.
+
+    Args:
+        image_sizes (List[List[int, int]]): [H, W] of every image.
+        tensor_outputs (Dict[str, Tensor]): external_output to its tensor.
+
+        force_mask_on (Bool): if true, make sure there will be pred_masks even
+            if the mask is not found in tensor_outputs (usually due to model crash)
+    """
+
+    results = [Instances(image_size) for image_size in image_sizes]
+
+    batch_splits = tensor_outputs.get("batch_splits", None)
+    if batch_splits:  # NOTE(review): truthiness test; presumably a tensor here - confirm
+        raise NotImplementedError()
+    assert len(image_sizes) == 1
+    result = results[0]
+
+    bbox_nms = tensor_outputs["bbox_nms"]
+    score_nms = tensor_outputs["score_nms"]
+    class_nms = tensor_outputs["class_nms"]
+    # Detection will always succeed because Conv supports 0-batch
+    assert bbox_nms is not None
+    assert score_nms is not None
+    assert class_nms is not None
+    if bbox_nms.shape[1] == 5:  # five columns are wrapped as RotatedBoxes
+        result.pred_boxes = RotatedBoxes(bbox_nms)
+    else:
+        result.pred_boxes = Boxes(bbox_nms)
+    result.scores = score_nms
+    result.pred_classes = class_nms.to(torch.int64)
+
+    mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None)
+    if mask_fcn_probs is not None:
+        # finish the mask pred: keep, for each instance, the mask of its predicted class
+        mask_probs_pred = mask_fcn_probs
+        num_masks = mask_probs_pred.shape[0]
+        class_pred = result.pred_classes
+        indices = torch.arange(num_masks, device=class_pred.device)
+        mask_probs_pred = mask_probs_pred[indices, class_pred][:, None]
+        result.pred_masks = mask_probs_pred
+    elif force_mask_on:
+        # NOTE: there's no way to know the height/width of mask here, it won't be
+        # used anyway when batch size is 0, so just set them to 0.
+        result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8)  # default device
+
+    keypoints_out = tensor_outputs.get("keypoints_out", None)
+    kps_score = tensor_outputs.get("kps_score", None)
+    if keypoints_out is not None:
+        # keypoints_out: [N, 4, #keypoints], where 4 is in order of (x, y, score, prob)
+        keypoints_tensor = keypoints_out
+        # NOTE: it's possible that prob is not calculated if "should_output_softmax"
+        # is set to False in HeatmapMaxKeypoint, so just using raw score, seems
+        # it doesn't affect mAP. TODO: check more carefully.
+        keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]]
+        result.pred_keypoints = keypoint_xyp
+    elif kps_score is not None:
+        # keypoint heatmap to sparse data structure
+        pred_keypoint_logits = kps_score
+        keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result])
+
+    return results
+
+
+def _cast_to_f32(f64):  # round-trip a Python float through struct's "f" (C float) format
+    return struct.unpack("f", struct.pack("f", f64))[0]
+
+
+def set_caffe2_compatible_tensor_mode(model, enable=True):  # set tensor_mode on submodules
+    def _fn(m):
+        if isinstance(m, Caffe2Compatible):  # other module types are left untouched
+            m.tensor_mode = enable
+
+    model.apply(_fn)  # model.apply calls _fn on the modules of `model`
+
+
+def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device):
+    """
+    See Caffe2MetaArch.get_caffe2_inputs(); returns (image tensor, im_info), both on `device`.
+    """
+    assert all(isinstance(x, dict) for x in batched_inputs)
+    assert all(x["image"].dim() == 3 for x in batched_inputs)
+
+    images = [x["image"] for x in batched_inputs]
+    images = ImageList.from_tensors(images, size_divisibility)
+
+    im_info = []
+    for input_per_image, image_size in zip(batched_inputs, images.image_sizes):
+        target_height = input_per_image.get("height", image_size[0])
+        target_width = input_per_image.get("width", image_size[1])  # noqa
+        # NOTE: The scale inside im_info is kept as convention and for providing
+        # post-processing information if further processing is needed. For
+        # current Caffe2 model definitions that don't include post-processing inside
+        # the model, this number is not used.
+        # NOTE: There can be a slight difference between width and height
+        # scales, using a single number can result in a numerical difference
+        # compared with D2's post-processing.
+        scale = target_height / image_size[0]  # height-only scale; target_width is unused
+        im_info.append([image_size[0], image_size[1], scale])
+    im_info = torch.Tensor(im_info)
+
+    return images.tensor.to(device), im_info.to(device)
+
+
+class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module):
+    """
+    Base class for caffe2-compatible implementation of a meta architecture.
+    The forward is traceable and its traced graph can be converted to caffe2
+    graph through ONNX.
+    """
+
+    def __init__(self, cfg, torch_model):
+        """
+        Args:
+            cfg (CfgNode): not used by this base constructor.
+            torch_model (nn.Module): the detectron2 model (meta_arch) to be
+                converted.
+        """
+        super().__init__()
+        self._wrapped_model = torch_model
+        self.eval()
+        set_caffe2_compatible_tensor_mode(self, True)
+
+    def get_caffe2_inputs(self, batched_inputs):
+        """
+        Convert pytorch-style structured inputs to caffe2-style inputs that
+        are tuples of tensors.
+
+        Args:
+            batched_inputs (list[dict]): inputs to a detectron2 model
+                in its standard format. Each dict has "image" (CHW tensor), and optionally
+                "height" and "width".
+
+        Returns:
+            tuple[Tensor]:
+                tuple of tensors that will be the inputs to the
+                :meth:`forward` method. For existing models, the first
+                is an NCHW tensor (padded and batched); the second is
+                an im_info Nx3 tensor, where the rows are
+                (height, width, unused legacy parameter)
+        """
+        return convert_batched_inputs_to_c2_format(
+            batched_inputs,
+            self._wrapped_model.backbone.size_divisibility,
+            self._wrapped_model.device,
+        )
+
+    def encode_additional_info(self, predict_net, init_net):
+        """
+        Save extra metadata that will be used by inference in the output protobuf.
+        """
+        pass  # no-op in the base class; subclasses in this file override it
+
+    def forward(self, inputs):
+        """
+        Run the forward in caffe2-style. It has to use caffe2-compatible ops
+        and the method will be used for tracing.
+
+        Args:
+            inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_inputs`.
+                They will be the inputs of the converted caffe2 graph.
+
+        Returns:
+            tuple[Tensor]: output tensors. They will be the outputs of the
+                converted caffe2 graph.
+        """
+        raise NotImplementedError
+
+    def _caffe2_preprocess_image(self, inputs):
+        """
+        Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward.
+        It normalizes the input images, and the final caffe2 graph assumes the
+        inputs have been batched already.
+        """
+        data, im_info = inputs
+        data = alias(data, "data")
+        im_info = alias(im_info, "im_info")
+        mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std
+        normalized_data = (data - mean) / std
+        normalized_data = alias(normalized_data, "normalized_data")
+
+        # Pack (data, im_info) into ImageList which is recognized by self.inference.
+        images = ImageList(tensor=normalized_data, image_sizes=im_info)
+        return images
+
+    @staticmethod
+    def get_outputs_converter(predict_net, init_net):
+        """
+        Creates a function that converts outputs of the caffe2 model to
+        detectron2's standard format.
+        The function uses information in `predict_net` and `init_net` that are
+        available at inference time. Therefore the function logic can be used in inference.
+
+        The returned function has the following signature:
+
+            def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs
+
+        Where
+
+            * batched_inputs (list[dict]): the original input format of the meta arch
+            * c2_inputs (tuple[Tensor]): the caffe2 inputs.
+            * c2_results (dict[str, Tensor]): the caffe2 output format,
+                corresponding to the outputs of the :meth:`forward` function.
+            * detectron2_outputs: the original output format of the meta arch.
+
+        This function can be used to compare the outputs of the original meta arch and
+        the converted caffe2 graph.
+
+        Returns:
+            callable: a callable of the above signature.
+        """
+        raise NotImplementedError
+
+
+class Caffe2GeneralizedRCNN(Caffe2MetaArch):  # caffe2-export wrapper for GeneralizedRCNN
+    def __init__(self, cfg, torch_model):
+        assert isinstance(torch_model, meta_arch.GeneralizedRCNN)
+        torch_model = patch_generalized_rcnn(torch_model)
+        super().__init__(cfg, torch_model)
+
+        try:
+            use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT
+        except AttributeError:  # option missing from cfg: fall back to False
+            use_heatmap_max_keypoint = False
+        self.roi_heads_patcher = ROIHeadsPatcher(
+            self._wrapped_model.roi_heads, use_heatmap_max_keypoint
+        )
+
+    def encode_additional_info(self, predict_net, init_net):
+        size_divisibility = self._wrapped_model.backbone.size_divisibility
+        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
+        check_set_pb_arg(
+            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
+        )
+        check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN")
+
+    @mock_torch_nn_functional_interpolate()
+    def forward(self, inputs):
+        if not self.tensor_mode:  # outside tensor mode, defer to the wrapped model's inference
+            return self._wrapped_model.inference(inputs)
+        images = self._caffe2_preprocess_image(inputs)
+        features = self._wrapped_model.backbone(images.tensor)
+        proposals, _ = self._wrapped_model.proposal_generator(images, features)
+        with self.roi_heads_patcher.mock_roi_heads():
+            detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals)
+        return tuple(detector_results[0].flatten())
+
+    @staticmethod
+    def get_outputs_converter(predict_net, init_net):
+        def f(batched_inputs, c2_inputs, c2_results):
+            _, im_info = c2_inputs
+            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]  # first two im_info columns
+            results = assemble_rcnn_outputs_by_name(image_sizes, c2_results)
+            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
+
+        return f
+
+
+class Caffe2RetinaNet(Caffe2MetaArch):  # caffe2-export wrapper for meta_arch.RetinaNet
+    def __init__(self, cfg, torch_model):
+        assert isinstance(torch_model, meta_arch.RetinaNet)
+        super().__init__(cfg, torch_model)
+
+    @mock_torch_nn_functional_interpolate()
+    def forward(self, inputs):
+        assert self.tensor_mode
+        images = self._caffe2_preprocess_image(inputs)
+
+        # explicitly return the images sizes to avoid removing "im_info" by ONNX
+        # since it's not used in the forward path
+        return_tensors = [images.image_sizes]
+
+        features = self._wrapped_model.backbone(images.tensor)
+        features = [features[f] for f in self._wrapped_model.head_in_features]
+        for i, feature_i in enumerate(features):
+            features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True)
+            return_tensors.append(features[i])
+
+        pred_logits, pred_anchor_deltas = self._wrapped_model.head(features)
+        for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)):
+            return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i)))
+            return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i)))
+
+        return tuple(return_tensors)
+
+    def encode_additional_info(self, predict_net, init_net):
+        size_divisibility = self._wrapped_model.backbone.size_divisibility
+        check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility)
+        check_set_pb_arg(
+            predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii")
+        )
+        check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet")
+
+        # Inference parameters (read back by get_outputs_converter):
+        check_set_pb_arg(
+            predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh)
+        )
+        check_set_pb_arg(
+            predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates
+        )
+        check_set_pb_arg(
+            predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh)
+        )
+        check_set_pb_arg(
+            predict_net,
+            "max_detections_per_image",
+            "i",
+            self._wrapped_model.max_detections_per_image,
+        )
+
+        check_set_pb_arg(
+            predict_net,
+            "bbox_reg_weights",
+            "floats",
+            [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights],
+        )
+        self._encode_anchor_generator_cfg(predict_net)
+
+    def _encode_anchor_generator_cfg(self, predict_net):
+        # serialize anchor_generator for future use
+        serialized_anchor_generator = io.BytesIO()
+        torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator)
+        # Ideally we can put anchor generating inside the model, then we don't
+        # need to store this information.
+        payload = serialized_anchor_generator.getvalue()  # renamed: don't shadow builtin `bytes`
+        check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", payload)
+
+    @staticmethod
+    def get_outputs_converter(predict_net, init_net):
+        self = types.SimpleNamespace()  # stand-in object carrying the attributes set below
+        serialized_anchor_generator = io.BytesIO(
+            get_pb_arg_vals(predict_net, "serialized_anchor_generator", None)
+        )
+        # NOTE(security): torch.load unpickles data; only use with trusted predict_net files.
+        self.anchor_generator = torch.load(serialized_anchor_generator)
+        bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None)
+        self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights))
+        self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None)
+        self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None)
+        self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None)
+        self.max_detections_per_image = get_pb_arg_vali(
+            predict_net, "max_detections_per_image", None
+        )
+        # hack to reuse inference code from RetinaNet: bind its methods to the stand-in
+        for meth in [
+            "forward_inference",
+            "inference_single_image",
+            "_transpose_dense_predictions",
+            "_decode_multi_level_predictions",
+            "_decode_per_level_predictions",
+        ]:
+            setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self))
+
+        def f(batched_inputs, c2_inputs, c2_results):
+            _, im_info = c2_inputs
+            image_sizes = [[int(im[0]), int(im[1])] for im in im_info]
+            dummy_images = ImageList(
+                torch.randn(
+                    (
+                        len(im_info),
+                        3,
+                    )
+                    + tuple(image_sizes[0])
+                ),
+                image_sizes,
+            )
+
+            num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")])
+            pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)]
+            pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)]
+
+            # For each feature level, feature should have the same batch size and
+            # spatial dimension as the box_cls and box_delta.
+            dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits]
+            # self.num_classes can be inferred
+            self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4)
+
+            results = self.forward_inference(
+                dummy_images, dummy_features, [pred_logits, pred_anchor_deltas]
+            )
+            return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes)
+
+        return f
+
+
+META_ARCH_CAFFE2_EXPORT_TYPE_MAP = {  # keys match the "meta_architecture" values encoded above
+    "GeneralizedRCNN": Caffe2GeneralizedRCNN,
+    "RetinaNet": Caffe2RetinaNet,
+}
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_patch.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c197cac1e7d5f665b6cbda46268716b1222f217
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/caffe2_patch.py
@@ -0,0 +1,152 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import contextlib
+from unittest import mock
+import torch
+
+from annotator.oneformer.detectron2.modeling import poolers
+from annotator.oneformer.detectron2.modeling.proposal_generator import rpn
+from annotator.oneformer.detectron2.modeling.roi_heads import keypoint_head, mask_head
+from annotator.oneformer.detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
+
+from .c10 import (
+ Caffe2Compatible,
+ Caffe2FastRCNNOutputsInference,
+ Caffe2KeypointRCNNInference,
+ Caffe2MaskRCNNInference,
+ Caffe2ROIPooler,
+ Caffe2RPN,
+)
+
+
+class GenericMixin(object):  # marker base: Caffe2CompatibleConverter mixes subclasses in
+    pass
+
+
+class Caffe2CompatibleConverter(object):
+    """
+    A GenericUpdater which implements the `create_from` interface by modifying
+    the module object in place and assigning it another class, replaceCls.
+    """
+
+    def __init__(self, replaceCls):
+        self.replaceCls = replaceCls
+
+    def create_from(self, module):
+        # update module's class to the new class
+        assert isinstance(module, torch.nn.Module)
+        if issubclass(self.replaceCls, GenericMixin):
+            # replaceCls should act as mixin, create a new class on-the-fly
+            new_class = type(
+                "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__),
+                (self.replaceCls, module.__class__),
+                {},  # {"new_method": lambda self: ...},
+            )
+            module.__class__ = new_class
+        else:
+            # replaceCls is a complete class; this allows an arbitrary class swap
+            module.__class__ = self.replaceCls
+
+        # initialize Caffe2Compatible (its __init__ is not run by the class swap above)
+        if isinstance(module, Caffe2Compatible):
+            module.tensor_mode = False
+
+        return module
+
+
+def patch(model, target, updater, *args, **kwargs):
+    """
+    recursively (post-order) update all modules with the target type and its
+    subclasses, make an initialization/composition/inheritance/... via the
+    updater.create_from.
+    """
+    for name, module in model.named_children():  # children are patched before `model` itself
+        model._modules[name] = patch(module, target, updater, *args, **kwargs)
+    if isinstance(model, target):
+        return updater.create_from(model, *args, **kwargs)
+    return model
+
+
+def patch_generalized_rcnn(model):  # swap RPN and ROIPooler modules for Caffe2 versions
+    ccc = Caffe2CompatibleConverter
+    model = patch(model, rpn.RPN, ccc(Caffe2RPN))
+    model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler))
+
+    return model
+
+
+@contextlib.contextmanager
+def mock_fastrcnn_outputs_inference(
+    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
+):  # context manager: replaces `box_predictor_type.inference` while active
+    with mock.patch.object(
+        box_predictor_type,
+        "inference",
+        autospec=True,
+        side_effect=Caffe2FastRCNNOutputsInference(tensor_mode),
+    ) as mocked_func:
+        yield
+    if check:  # verify that the patched function was actually called
+        assert mocked_func.call_count > 0
+
+
+@contextlib.contextmanager
+def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):  # tensor_mode is unused
+    with mock.patch(
+        "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference()
+    ) as mocked_func:
+        yield
+    if check:  # verify that the patched function was actually called
+        assert mocked_func.call_count > 0
+
+
+@contextlib.contextmanager
+def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
+    with mock.patch(  # `tensor_mode` is accepted but not used in this function
+        "{}.keypoint_rcnn_inference".format(patched_module),
+        side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint),
+    ) as mocked_func:
+        yield
+    if check:  # verify that the patched function was actually called
+        assert mocked_func.call_count > 0
+
+
+class ROIHeadsPatcher:  # builds the set of inference mocks matching a given ROI heads module
+    def __init__(self, heads, use_heatmap_max_keypoint):
+        self.heads = heads
+        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint
+
+    @contextlib.contextmanager
+    def mock_roi_heads(self, tensor_mode=True):
+        """
+        Patching several inference functions inside ROIHeads and its subclasses
+
+        Args:
+            tensor_mode (bool): whether the inputs/outputs are caffe2's tensor
+                format or not. Default to True.
+        """
+        # NOTE: this requires that `keypoint_rcnn_inference` and `mask_rcnn_inference`
+        # are called inside the same file as BaseXxxHead due to using mock.patch.
+        kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__
+        mask_head_mod = mask_head.BaseMaskRCNNHead.__module__
+
+        mock_ctx_managers = [
+            mock_fastrcnn_outputs_inference(
+                tensor_mode=tensor_mode,
+                check=True,
+                box_predictor_type=type(self.heads.box_predictor),
+            )
+        ]
+        if getattr(self.heads, "keypoint_on", False):  # only when the heads enable keypoints
+            mock_ctx_managers += [
+                mock_keypoint_rcnn_inference(
+                    tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint
+                )
+            ]
+        if getattr(self.heads, "mask_on", False):  # only when the heads enable masks
+            mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)]
+
+        with contextlib.ExitStack() as stack:  # python 3.3+
+            for mgr in mock_ctx_managers:
+                stack.enter_context(mgr)
+            yield
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/flatten.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/flatten.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fcb2bf49a0adad2798a10781a42accd9571218f
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/flatten.py
@@ -0,0 +1,330 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import collections
+from dataclasses import dataclass
+from typing import Callable, List, Optional, Tuple
+import torch
+from torch import nn
+
+from annotator.oneformer.detectron2.structures import Boxes, Instances, ROIMasks
+from annotator.oneformer.detectron2.utils.registry import _convert_target_to_string, locate
+
+from .torchscript_patch import patch_builtin_len
+
+
+@dataclass
+class Schema:
+    """
+    A Schema defines how to flatten a possibly hierarchical object into tuple of
+    primitive objects, so it can be used as inputs/outputs of PyTorch's tracing.
+
+    PyTorch does not support tracing a function that produces rich output
+    structures (e.g. dict, Instances, Boxes). To trace such a function, we
+    flatten the rich object into tuple of tensors, and return this tuple of tensors
+    instead. Meanwhile, we also need to know how to "rebuild" the original object
+    from the flattened results, so we can evaluate the flattened results.
+    A Schema defines how to flatten an object, and while flattening it, it records
+    necessary schemas so that the object can be rebuilt using the flattened outputs.
+
+    The flattened object and the schema object are returned by the ``.flatten`` classmethod.
+    Then the original object can be rebuilt with the ``__call__`` method of schema.
+
+    A Schema is a dataclass that can be serialized easily.
+    """
+
+    # inspired by FetchMapper in tensorflow/python/client/session.py
+
+    @classmethod
+    def flatten(cls, obj):  # subclasses return (flattened tuple, schema instance)
+        raise NotImplementedError
+
+    def __call__(self, values):  # subclasses rebuild the original object from `values`
+        raise NotImplementedError
+
+    @staticmethod
+    def _concat(values):  # join tuples into one tuple and record each tuple's length
+        ret = ()
+        sizes = []
+        for v in values:
+            assert isinstance(v, tuple), "Flattened results must be a tuple"
+            ret = ret + v
+            sizes.append(len(v))
+        return ret, sizes
+
+    @staticmethod
+    def _split(values, sizes):  # inverse of _concat: consecutive slices of the given sizes
+        if len(sizes):
+            expected_len = sum(sizes)
+            assert (
+                len(values) == expected_len
+            ), f"Values has length {len(values)} but expect length {expected_len}."
+        ret, begin = [], 0  # running offset instead of re-summing sizes[:k] per element
+        for size in sizes:
+            ret.append(values[begin : begin + size])
+            begin += size
+        return ret
+
+
+@dataclass
+class ListSchema(Schema):
+    schemas: List[Schema]  # the schemas that define how to flatten each element in the list
+    sizes: List[int]  # the flattened length of each element
+
+    def __call__(self, values):  # rebuild a list from the flattened `values`
+        values = self._split(values, self.sizes)
+        if len(values) != len(self.schemas):
+            raise ValueError(
+                f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!"
+            )
+        values = [m(v) for m, v in zip(self.schemas, values)]
+        return list(values)
+
+    @classmethod
+    def flatten(cls, obj):  # flatten every element, then concatenate the pieces
+        res = [flatten_to_tuple(k) for k in obj]
+        values, sizes = cls._concat([k[0] for k in res])
+        return values, cls([k[1] for k in res], sizes)
+
+
+@dataclass
+class TupleSchema(ListSchema):  # same as ListSchema, but rebuilds a tuple
+    def __call__(self, values):
+        return tuple(super().__call__(values))
+
+
+@dataclass
+class IdentitySchema(Schema):  # leaf schema: the object is its own flattened form
+    def __call__(self, values):
+        return values[0]
+
+    @classmethod
+    def flatten(cls, obj):  # wrap the object in a 1-tuple
+        return (obj,), cls()
+
+
+@dataclass
+class DictSchema(ListSchema):
+    keys: List[str]  # dictionary keys, stored in sorted order by flatten()
+
+    def __call__(self, values):  # rebuild a dict by pairing the stored keys with the values
+        values = super().__call__(values)
+        return dict(zip(self.keys, values))
+
+    @classmethod
+    def flatten(cls, obj):
+        for k in obj.keys():
+            if not isinstance(k, str):
+                raise KeyError("Only support flattening dictionaries if keys are str.")
+        keys = sorted(obj.keys())  # sorted, so flattening does not depend on insertion order
+        values = [obj[k] for k in keys]
+        ret, schema = ListSchema.flatten(values)
+        return ret, cls(schema.schemas, schema.sizes, keys)
+
+
+@dataclass
+class InstancesSchema(DictSchema):  # flattens the fields as a dict, plus image_size last
+    def __call__(self, values):
+        image_size, fields = values[-1], values[:-1]  # flatten() appended image_size at the end
+        fields = super().__call__(fields)
+        return Instances(image_size, **fields)
+
+    @classmethod
+    def flatten(cls, obj):
+        ret, schema = super().flatten(obj.get_fields())
+        size = obj.image_size
+        if not isinstance(size, torch.Tensor):  # non-tensor sizes are converted to a tensor
+            size = torch.tensor(size)
+        return ret + (size,), schema
+
+
+@dataclass
+class TensorWrapSchema(Schema):
+    """
+    For classes that are simple wrappers of tensors, e.g.
+    Boxes, RotatedBoxes, BitMasks
+    """
+
+    class_name: str  # string form of the wrapper class, resolved again with locate()
+
+    def __call__(self, values):  # re-create the wrapper around the single flattened tensor
+        return locate(self.class_name)(values[0])
+
+    @classmethod
+    def flatten(cls, obj):  # flattened form is the wrapper's `.tensor` attribute
+        return (obj.tensor,), cls(_convert_target_to_string(type(obj)))
+
+
+# if more custom structures are needed in the future, we can allow
+# passing in extra schemas for custom types
+def flatten_to_tuple(obj):
+    """
+    Flatten an object so it can be used for PyTorch tracing.
+    Also returns how to rebuild the original object from the flattened outputs.
+
+    Returns:
+        res (tuple): the flattened results that can be used as tracing outputs
+        schema: an object with a ``__call__`` method such that ``schema(res) == obj``.
+            It is a pure dataclass that can be serialized.
+    """
+    schemas = [
+        ((str, bytes), IdentitySchema),
+        (list, ListSchema),
+        (tuple, TupleSchema),
+        (collections.abc.Mapping, DictSchema),
+        (Instances, InstancesSchema),
+        ((Boxes, ROIMasks), TensorWrapSchema),
+    ]
+    for klass, schema in schemas:  # first matching entry wins
+        if isinstance(obj, klass):
+            F = schema
+            break
+    else:  # no entry matched: treat the object as a leaf
+        F = IdentitySchema
+
+    return F.flatten(obj)
+
+
+class TracingAdapter(nn.Module):
+    """
+    A model may take rich input/output format (e.g. dict or custom classes),
+    but `torch.jit.trace` requires tuple of tensors as input/output.
+    This adapter flattens input/output format of a model so it becomes traceable.
+
+    It also records the necessary schema to rebuild model's inputs/outputs from flattened
+    inputs/outputs.
+
+    Example:
+    ::
+        outputs = model(inputs)   # inputs/outputs may be rich structure
+        adapter = TracingAdapter(model, inputs)
+
+        # can now trace the model, with adapter.flattened_inputs, or another
+        # tuple of tensors with the same length and meaning
+        traced = torch.jit.trace(adapter, adapter.flattened_inputs)
+
+        # traced model can only produce flattened outputs (tuple of tensors)
+        flattened_outputs = traced(*adapter.flattened_inputs)
+        # adapter knows the schema to convert it back (new_outputs == outputs)
+        new_outputs = adapter.outputs_schema(flattened_outputs)
+    """
+
+    flattened_inputs: Tuple[torch.Tensor] = None
+    """
+    Flattened version of inputs given to this class's constructor.
+    """
+
+    inputs_schema: Schema = None
+    """
+    Schema of the inputs given to this class's constructor.
+    """
+
+    outputs_schema: Schema = None
+    """
+    Schema of the output produced by calling the given model with inputs.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        inputs,
+        inference_func: Optional[Callable] = None,
+        allow_non_tensor: bool = False,
+    ):
+        """
+        Args:
+            model: an nn.Module
+            inputs: An input argument or a tuple of input arguments used to call model.
+                After flattening, it has to only consist of tensors.
+            inference_func: a callable that takes (model, *inputs), calls the
+                model with inputs, and return outputs. By default it
+                is ``lambda model, *inputs: model(*inputs)``. Can be overridden
+                if you need to call the model differently.
+            allow_non_tensor: allow inputs/outputs to contain non-tensor objects.
+                This option will filter out non-tensor objects to make the
+                model traceable, but ``inputs_schema``/``outputs_schema`` cannot be
+                used anymore because inputs/outputs cannot be rebuilt from pure tensors.
+                This is useful when you're only interested in the single trace of
+                execution (e.g. for flop count), but not interested in
+                generalizing the traced graph to new inputs.
+        """
+        super().__init__()
+        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
+            model = model.module  # unwrap the parallel wrapper and use the inner module
+        self.model = model
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        self.inputs = inputs
+        self.allow_non_tensor = allow_non_tensor
+
+        if inference_func is None:
+            inference_func = lambda model, *inputs: model(*inputs)  # noqa
+        self.inference_func = inference_func
+
+        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)
+
+        if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs):
+            return  # all inputs are tensors, nothing more to do
+        if self.allow_non_tensor:
+            self.flattened_inputs = tuple(
+                [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)]
+            )
+            self.inputs_schema = None  # inputs can no longer be rebuilt from tensors alone
+        else:
+            for input in self.flattened_inputs:
+                if not isinstance(input, torch.Tensor):
+                    raise ValueError(
+                        "Inputs for tracing must only contain tensors. "
+                        f"Got a {type(input)} instead."
+                    )
+
+    def forward(self, *args: torch.Tensor):
+        with torch.no_grad(), patch_builtin_len():
+            if self.inputs_schema is not None:
+                inputs_orig_format = self.inputs_schema(args)
+            else:  # no schema: only the exact tensors stored at construction are accepted
+                if len(args) != len(self.flattened_inputs) or any(
+                    x is not y for x, y in zip(args, self.flattened_inputs)
+                ):
+                    raise ValueError(
+                        "TracingAdapter does not contain valid inputs_schema."
+                        " So it cannot generalize to other inputs and must be"
+                        " traced with `.flattened_inputs`."
+                    )
+                inputs_orig_format = self.inputs
+
+            outputs = self.inference_func(self.model, *inputs_orig_format)
+            flattened_outputs, schema = flatten_to_tuple(outputs)
+
+            flattened_output_tensors = tuple(
+                [x for x in flattened_outputs if isinstance(x, torch.Tensor)]
+            )
+            if len(flattened_output_tensors) < len(flattened_outputs):  # some outputs not tensors
+                if self.allow_non_tensor:
+                    flattened_outputs = flattened_output_tensors
+                    self.outputs_schema = None
+                else:
+                    raise ValueError(
+                        "Model cannot be traced because some model outputs "
+                        "cannot flatten to tensors."
+                    )
+            else:  # schema is valid
+                if self.outputs_schema is None:
+                    self.outputs_schema = schema
+                else:
+                    assert self.outputs_schema == schema, (
+                        "Model should always return outputs with the same "
+                        "structure so it can be traced!"
+                    )
+            return flattened_outputs
+
+    def _create_wrapper(self, traced_model):
+        """
+        Return a function that has an input/output interface the same as the
+        original model, but it calls the given traced model under the hood.
+        """
+
+        def forward(*args):
+            flattened_inputs, _ = flatten_to_tuple(args)
+            flattened_outputs = traced_model(*flattened_inputs)
+            return self.outputs_schema(flattened_outputs)
+
+        return forward
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/shared.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..53ba9335e26819f9381115eba17bbbe3816b469c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/shared.py
@@ -0,0 +1,1039 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import collections
+import copy
+import functools
+import logging
+import numpy as np
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from unittest import mock
+import caffe2.python.utils as putils
+import torch
+import torch.nn.functional as F
+from caffe2.proto import caffe2_pb2
+from caffe2.python import core, net_drawer, workspace
+from torch.nn.functional import interpolate as interp
+
+logger = logging.getLogger(__name__)
+
+
+# ==== torch/utils_toffee/cast.py =======================================
+
+
+def to_device(t, device_str):
+    """
+    Traceable stand-in for ``t.to(device)``: the transfer is expressed through
+    the explicit Caffe2 copy ops, and nothing is emitted when the tensor is
+    already on the requested device.
+
+    Args:
+        t: tensor to move.
+        device_str: target device, anything accepted by ``torch.device``.
+
+    Returns:
+        ``t`` itself when no move is needed, otherwise the result of the copy op.
+
+    Raises:
+        RuntimeError: for any transfer other than cuda->cpu or cpu->cuda.
+    """
+    src = t.device
+    dst = torch.device(device_str)
+    if src == dst:
+        return t
+
+    transfer = (src.type, dst.type)
+    if transfer == ("cuda", "cpu"):
+        return torch.ops._caffe2.CopyGPUToCPU(t)
+    if transfer == ("cpu", "cuda"):
+        return torch.ops._caffe2.CopyCPUToGPU(t)
+    raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst))
+
+
+# ==== torch/utils_toffee/interpolate.py =======================================
+
+
+# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py
+def BilinearInterpolation(tensor_in, up_scale):
+    """
+    Upsample ``tensor_in`` by ``up_scale`` with a transposed convolution whose
+    weights form a fixed bilinear filter, applied to each channel independently.
+
+    Args:
+        tensor_in: input tensor; its dimension 1 is used as the channel count.
+        up_scale: even upsampling factor.
+
+    Returns:
+        The output of ``F.conv_transpose2d``.
+    """
+    assert up_scale % 2 == 0, "Scale should be even"
+
+    def upsample_filt(size):
+        factor = (size + 1) // 2
+        center = factor - 1 if size % 2 == 1 else factor - 0.5
+        og = np.ogrid[:size, :size]
+        row_weight = 1 - abs(og[0] - center) / factor
+        col_weight = 1 - abs(og[1] - center) / factor
+        return row_weight * col_weight
+
+    stride = int(up_scale)
+    kernel_size = stride * 2
+    bil_filt = upsample_filt(kernel_size)
+
+    channels = int(tensor_in.shape[1])
+    kernel = np.zeros((channels, channels, kernel_size, kernel_size), dtype=np.float32)
+    # Only the (c, c) slices are filled, so channels are never mixed.
+    diagonal = range(channels)
+    kernel[diagonal, diagonal, :, :] = bil_filt
+
+    return F.conv_transpose2d(
+        tensor_in,
+        weight=to_device(torch.Tensor(kernel), tensor_in.device),
+        bias=None,
+        stride=stride,
+        padding=int(up_scale / 2),
+    )
+
+
+# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if
+# using dynamic `scale_factor` rather than static `size`. (T43166860)
+# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly.
+def onnx_compatibale_interpolate(
+    input, size=None, scale_factor=None, mode="nearest", align_corners=None
+):
+    """
+    Replacement for ``torch.nn.functional.interpolate`` used during ONNX export.
+
+    For a 4-D input resized by ``scale_factor`` (no explicit ``size``), "nearest"
+    maps to the Caffe2 ``ResizeNearest`` op and "bilinear" to
+    ``BilinearInterpolation``. Every other case falls back to the regular
+    ``interpolate``, with a warning when the output size is not static.
+    """
+    # NOTE: The input dimensions are interpreted in the form:
+    # `mini-batch x channels x [optional depth] x [optional height] x width`.
+    uses_dynamic_scale = size is None and scale_factor is not None
+
+    if uses_dynamic_scale and input.dim() == 4:
+        if isinstance(scale_factor, (int, float)):
+            height_scale = width_scale = scale_factor
+        else:
+            assert isinstance(scale_factor, (tuple, list))
+            assert len(scale_factor) == 2
+            height_scale, width_scale = scale_factor
+
+        assert not align_corners, "No matching C2 op for align_corners == True"
+        if mode == "nearest":
+            return torch.ops._caffe2.ResizeNearest(
+                input, order="NCHW", width_scale=width_scale, height_scale=height_scale
+            )
+        if mode == "bilinear":
+            logger.warning(
+                "Use F.conv_transpose2d for bilinear interpolate"
+                " because there's no such C2 op, this may cause significant"
+                " slowdown and the boundary pixels won't be as same as"
+                " using F.interpolate due to padding."
+            )
+            assert height_scale == width_scale
+            return BilinearInterpolation(input, up_scale=height_scale)
+
+    if uses_dynamic_scale:
+        logger.warning("Output size is not static, it might cause ONNX conversion issue")
+
+    return interp(input, size, scale_factor, mode, align_corners)
+
+
+def mock_torch_nn_functional_interpolate():
+    """
+    Decorator factory. While an ONNX export is in progress, the wrapped function
+    runs with ``torch.nn.functional.interpolate`` patched to
+    ``onnx_compatibale_interpolate``; otherwise it is called unchanged.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def _mock_torch_nn_functional_interpolate(*args, **kwargs):
+            if not torch.onnx.is_in_onnx_export():
+                return func(*args, **kwargs)
+
+            patcher = mock.patch(
+                "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate
+            )
+            with patcher:
+                return func(*args, **kwargs)
+
+        return _mock_torch_nn_functional_interpolate
+
+    return decorator
+
+
+# ==== torch/utils_caffe2/ws_utils.py ==========================================
+
+
+class ScopedWS(object):
+    """
+    Context manager around the Caffe2 ``workspace`` module.
+
+    On enter it remembers the current workspace, switches to ``ws_name``
+    (creating it if needed) unless ``ws_name`` is None, and resets the active
+    workspace when ``is_reset`` is set. On exit it resets the active workspace
+    when ``is_cleanup`` is set and switches back to the remembered one if a
+    switch had been requested.
+    """
+
+    def __init__(self, ws_name, is_reset, is_cleanup=False):
+        self.org_ws = ""
+        self.ws_name = ws_name
+        self.is_reset = is_reset
+        self.is_cleanup = is_cleanup
+
+    def __enter__(self):
+        self.org_ws = workspace.CurrentWorkspace()
+        switch_requested = self.ws_name is not None
+        if switch_requested:
+            workspace.SwitchWorkspace(self.ws_name, True)
+        if self.is_reset:
+            workspace.ResetWorkspace()
+        return workspace
+
+    def __exit__(self, *args):
+        if self.is_cleanup:
+            workspace.ResetWorkspace()
+        if self.ws_name is None:
+            return
+        workspace.SwitchWorkspace(self.org_ws)
+
+
+def fetch_any_blob(name):
+    """
+    Fetch blob ``name`` from the current workspace.
+
+    ``FetchBlob`` is tried first; on ``TypeError`` the blob is fetched with
+    ``FetchInt8Blob`` instead. Any other exception from ``FetchBlob`` is logged
+    and None is returned.
+    """
+    try:
+        return workspace.FetchBlob(name)
+    except TypeError:
+        return workspace.FetchInt8Blob(name)
+    except Exception as e:
+        logger.error("Get blob {} error: {}".format(name, e))
+    return None
+
+
+# ==== torch/utils_caffe2/protobuf.py ==========================================
+
+
+def get_pb_arg(pb, arg_name):
+    """Return the first entry of ``pb.arg`` whose ``name`` equals ``arg_name``, else None."""
+    return next((candidate for candidate in pb.arg if candidate.name == arg_name), None)
+
+
+def get_pb_arg_valf(pb, arg_name, default_val):
+    """Return the ``f`` field of argument ``arg_name``, or ``default_val`` if it is absent."""
+    found = get_pb_arg(pb, arg_name)
+    if found is None:
+        return default_val
+    return found.f
+
+
+def get_pb_arg_floats(pb, arg_name, default_val):
+    """Return the ``floats`` field of argument ``arg_name`` as a list of float, or ``default_val``."""
+    found = get_pb_arg(pb, arg_name)
+    if found is None:
+        return default_val
+    return [float(value) for value in found.floats]
+
+
+def get_pb_arg_ints(pb, arg_name, default_val):
+    """Return the ``ints`` field of argument ``arg_name`` as a list of int, or ``default_val``."""
+    found = get_pb_arg(pb, arg_name)
+    if found is None:
+        return default_val
+    return [int(value) for value in found.ints]
+
+
+def get_pb_arg_vali(pb, arg_name, default_val):
+    """Return the ``i`` field of argument ``arg_name``, or ``default_val`` if it is absent."""
+    found = get_pb_arg(pb, arg_name)
+    if found is None:
+        return default_val
+    return found.i
+
+
+def get_pb_arg_vals(pb, arg_name, default_val):
+    """Return the ``s`` field of argument ``arg_name``, or ``default_val`` if it is absent."""
+    found = get_pb_arg(pb, arg_name)
+    if found is None:
+        return default_val
+    return found.s
+
+
+def get_pb_arg_valstrings(pb, arg_name, default_val):
+    """Return the ``strings`` field of argument ``arg_name`` as a list, or ``default_val``."""
+    found = get_pb_arg(pb, arg_name)
+    if found is None:
+        return default_val
+    return list(found.strings)
+
+
+def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
+    """
+    Make sure ``pb`` carries argument ``arg_name`` with ``arg_attr == arg_value``.
+
+    A missing argument is created and appended to ``pb.arg``. An existing
+    argument holding a different value is overwritten (with a warning) when
+    ``allow_override`` is set; otherwise the mismatch trips an assertion.
+    """
+    entry = get_pb_arg(pb, arg_name)
+    if entry is None:
+        entry = putils.MakeArgument(arg_name, arg_value)
+        assert hasattr(entry, arg_attr)
+        pb.arg.extend([entry])
+
+    existing = getattr(entry, arg_attr)
+    needs_override = allow_override and existing != arg_value
+    if needs_override:
+        logger.warning("Override argument {}: {} -> {}".format(arg_name, existing, arg_value))
+        setattr(entry, arg_attr, arg_value)
+    else:
+        assert entry is not None
+        assert existing == arg_value, "Existing value {}, new value {}".format(
+            existing, arg_value
+        )
+
+
+def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
+    """
+    Create the ``GivenTensor*Fill`` operator that outputs blob ``name`` with the
+    content of the NumPy array ``tensor``. The op type is picked from the array
+    dtype (float32, int32, int64 or uint8); ``device_option`` is forwarded to
+    the operator when given.
+    """
+    assert type(tensor) == np.ndarray
+    fill_op_by_dtype = {
+        np.dtype("float32"): "GivenTensorFill",
+        np.dtype("int32"): "GivenTensorIntFill",
+        np.dtype("int64"): "GivenTensorInt64Fill",
+        np.dtype("uint8"): "GivenTensorStringFill",
+    }
+
+    if tensor.dtype == np.dtype("uint8"):
+        # uint8 arrays are emitted as a single string value of shape [1].
+        op_kwargs = {"values": [str(tensor.data)], "shape": [1]}
+    else:
+        op_kwargs = {"values": tensor, "shape": tensor.shape}
+
+    if device_option is not None:
+        op_kwargs["device_option"] = device_option
+
+    op_type = fill_op_by_dtype[tensor.dtype]
+    return core.CreateOperator(op_type, [], [name], **op_kwargs)
+
+
+def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
+    """
+    Create the ``Int8Given*Fill`` operator that outputs blob ``name`` from a
+    Caffe2 ``Int8Tensor``, carrying over its data, scale and zero point.
+    Only uint8 and int32 payloads are accepted.
+    """
+    assert type(int8_tensor) == workspace.Int8Tensor
+    fill_op_by_dtype = {
+        np.dtype("int32"): "Int8GivenIntTensorFill",
+        np.dtype("uint8"): "Int8GivenTensorFill",
+    }
+
+    payload = int8_tensor.data
+    assert payload.dtype in [np.dtype("uint8"), np.dtype("int32")]
+    # uint8 data is passed as raw bytes, int32 data as the array itself.
+    if payload.dtype == np.dtype("uint8"):
+        values = payload.tobytes()
+    else:
+        values = payload
+
+    return core.CreateOperator(
+        fill_op_by_dtype[payload.dtype],
+        [],
+        [name],
+        values=values,
+        shape=payload.shape,
+        Y_scale=int8_tensor.scale,
+        Y_zero_point=int8_tensor.zero_point,
+    )
+
+
+def create_const_fill_op(
+    name: str,
+    blob: Union[np.ndarray, workspace.Int8Tensor],
+    device_option: Optional[caffe2_pb2.DeviceOption] = None,
+) -> caffe2_pb2.OperatorDef:
+    """
+    Return the Caffe2 operator that produces ``blob`` as a constant named
+    ``name``. NumPy arrays and Caffe2 ``Int8Tensor`` objects are supported;
+    ``device_option`` must be None for an ``Int8Tensor``.
+    """
+    blob_type = type(blob)
+    supported_types = [np.ndarray, workspace.Int8Tensor]
+    assert (
+        blob_type in supported_types
+    ), 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
+        name, type(blob)
+    )
+
+    if blob_type == np.ndarray:
+        return _create_const_fill_op_from_numpy(name, blob, device_option)
+    if blob_type == workspace.Int8Tensor:
+        assert device_option is None
+        return _create_const_fill_op_from_c2_int8_tensor(name, blob)
+
+
+def construct_init_net_from_params(
+    params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
+) -> caffe2_pb2.NetDef:
+    """
+    Build an init_net with one constant-fill op per entry of ``params``.
+
+    Each blob name is also added to the net's ``external_output``. String blobs
+    are skipped with a warning. ``device_options`` optionally maps a blob name
+    to the device option of its fill op.
+    """
+    net = caffe2_pb2.NetDef()
+    per_blob_device = device_options or {}
+    for name, blob in params.items():
+        if isinstance(blob, str):
+            logger.warning(
+                "Blob {} with type {} is not supported in generating init net, skipped.".format(
+                    name, type(blob)
+                )
+            )
+            continue
+        fill_op = create_const_fill_op(name, blob, device_option=per_blob_device.get(name, None))
+        net.op.extend([fill_op])
+        net.external_output.append(name)
+    return net
+
+
+def get_producer_map(ssa):
+    """
+    Map every versioned output blob to ``(i, j)``: ``i`` is the index of the op
+    producing it and ``j`` its position among that op's outputs.
+    """
+    return {
+        blob: (op_idx, out_idx)
+        for op_idx, op_ssa in enumerate(ssa)
+        for out_idx, blob in enumerate(op_ssa[1])
+    }
+
+
+def get_consumer_map(ssa):
+    """
+    Map every versioned input blob to the list of ``(i, j)`` pairs consuming it:
+    ``i`` is the op index and ``j`` the position among that op's inputs.
+    The result is a ``defaultdict(list)``, so unknown blobs yield an empty list.
+    """
+    consumers = collections.defaultdict(list)
+    for op_idx, op_ssa in enumerate(ssa):
+        for in_idx, blob in enumerate(op_ssa[0]):
+            consumers[blob].append((op_idx, in_idx))
+    return consumers
+
+
+def get_params_from_init_net(
+    init_net: caffe2_pb2.NetDef,
+) -> Tuple[Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
+    """
+    Run ``init_net`` in a scratch workspace and collect its external outputs.
+
+    Returns:
+        params: dict from blob name to the fetched blob
+        device_options: dict from blob name to the device option of the op
+            that produced it
+    """
+    # NOTE: this assumes that the device of a param is determined by its producer
+    # op, the only exception being CopyGPUToCPU, which is a CUDA op but returns
+    # a CPU tensor.
+    def _get_device_option(producer_op):
+        if producer_op.type == "CopyGPUToCPU":
+            return caffe2_pb2.DeviceOption()
+        else:
+            return producer_op.device_option
+
+    # The scratch workspace is reset on entry and cleaned up on exit.
+    with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
+        ws.RunNetOnce(init_net)
+        params = {b: fetch_any_blob(b) for b in init_net.external_output}
+    ssa, versions = core.get_ssa(init_net)
+    producer_map = get_producer_map(ssa)
+    device_options = {
+        b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
+        for b in init_net.external_output
+    }
+    return params, device_options
+
+
+def _updater_raise(op, input_types, output_types):
+    """Raise RuntimeError reporting that no status update applies to ``op``."""
+    message = "Failed to apply updater for op {} given input_types {} and output_types {}".format(
+        op, input_types, output_types
+    )
+    raise RuntimeError(message)
+
+
+def _generic_status_identifier(
+ predict_net: caffe2_pb2.NetDef,
+ status_updater: Callable,
+ known_status: Dict[Tuple[str, int], Any],
+) -> Dict[Tuple[str, int], Any]:
+ """
+ Statically infer the status of each blob. The status can be, for example, the
+ device type (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc.
+ "Blob" here is a versioned blob (Tuple[str, int]) in the format compatible with ssa.
+ Inputs:
+ predict_net: the caffe2 network
+ status_updater: a callable, given an op and the status of its input/output,
+ it returns the updated status of input/output. `None` is used for
+ representing unknown status.
+ known_status: a dict containing known status, used as initialization.
+ It is deep-copied, so the caller's dict is not modified.
+ Outputs:
+ A dict mapping from versioned blob to its status
+ Raises:
+ RuntimeError: if two different statuses are inferred for the same blob.
+ NotImplementedError: if some blob used by an op still has no status after
+ one forward and one backward sweep over the ops.
+ """
+ ssa, versions = core.get_ssa(predict_net)
+ # External inputs are taken at version 0, external outputs at their final version.
+ versioned_ext_input = [(b, 0) for b in predict_net.external_input]
+ versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
+ all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])
+
+ # Every pre-seeded status must refer to a blob of this net and must be known.
+ allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
+ assert all(k in allowed_vbs for k in known_status)
+ assert all(v is not None for v in known_status.values())
+ _known_status = copy.deepcopy(known_status)
+
+ # Record `value` for `key`, refusing to overwrite a different existing status.
+ def _check_and_update(key, value):
+ assert value is not None
+ if key in _known_status:
+ if not _known_status[key] == value:
+ raise RuntimeError(
+ "Confilict status for {}, existing status {}, new status {}".format(
+ key, _known_status[key], value
+ )
+ )
+ _known_status[key] = value
+
+ # Ask the updater about one op and store every status it could determine.
+ def _update_i(op, ssa_i):
+ versioned_inputs = ssa_i[0]
+ versioned_outputs = ssa_i[1]
+
+ inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
+ outputs_status = [_known_status.get(b, None) for b in versioned_outputs]
+
+ new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)
+
+ for versioned_blob, status in zip(
+ versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
+ ):
+ if status is not None:
+ _check_and_update(versioned_blob, status)
+
+ # One sweep in op order, then one sweep in reverse op order.
+ for op, ssa_i in zip(predict_net.op, ssa):
+ _update_i(op, ssa_i)
+ for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
+ _update_i(op, ssa_i)
+
+ # NOTE: This strictly requires every blob of predict_net to end up with a
+ # known status. Sometimes that is impossible (e.g. with a dead-end op);
+ # this constraint may be relaxed if needed.
+ for k in all_versioned_blobs:
+ if k not in _known_status:
+ raise NotImplementedError(
+ "Can not infer the status for {}. Currently only support the case where"
+ " a single forward and backward pass can identify status for all blobs.".format(k)
+ )
+
+ return _known_status
+
+
+def infer_device_type(
+    predict_net: caffe2_pb2.NetDef,
+    known_status: Dict[Tuple[str, int], Any],
+    device_name_style: str = "caffe2",
+) -> Dict[Tuple[str, int], str]:
+    """
+    Return the device type of each (versioned) blob of ``predict_net``.
+
+    The CPU device is named "cpu"; the GPU device is named "gpu" with the
+    "caffe2" naming style and "cuda" with the "pytorch" style.
+    """
+    assert device_name_style in ["caffe2", "pytorch"]
+    _CPU_STR = "cpu"
+    _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"
+
+    def _copy_cpu_to_gpu_updater(op, input_types, output_types):
+        # A CPU->GPU copy cannot read from GPU or write to CPU.
+        contradicts = input_types[0] == _GPU_STR or output_types[0] == _CPU_STR
+        if contradicts:
+            _updater_raise(op, input_types, output_types)
+        return ([_CPU_STR], [_GPU_STR])
+
+    def _copy_gpu_to_cpu_updater(op, input_types, output_types):
+        # A GPU->CPU copy cannot read from CPU or write to GPU.
+        contradicts = input_types[0] == _CPU_STR or output_types[0] == _GPU_STR
+        if contradicts:
+            _updater_raise(op, input_types, output_types)
+        return ([_GPU_STR], [_CPU_STR])
+
+    def _other_ops_updater(op, input_types, output_types):
+        # Any other op keeps all of its inputs and outputs on one device.
+        known_types = [x for x in input_types + output_types if x is not None]
+        the_type = known_types[0] if len(known_types) > 0 else None
+        if any(x != the_type for x in known_types):
+            _updater_raise(op, input_types, output_types)
+        return ([the_type] * len(op.input), [the_type] * len(op.output))
+
+    copy_updaters = {
+        "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
+        "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
+    }
+
+    def _device_updater(op, *args, **kwargs):
+        updater = copy_updaters.get(op.type, _other_ops_updater)
+        return updater(op, *args, **kwargs)
+
+    return _generic_status_identifier(predict_net, _device_updater, known_status)
+
+
+# ==== torch/utils_caffe2/vis.py ===============================================
+
+
+def _modify_blob_names(ops, blob_rename_f):
+    """
+    Return deep copies of ``ops`` in which every input and output name has been
+    passed through ``blob_rename_f``. The given ops are left untouched.
+    """
+    renamed_ops = []
+    for original in ops:
+        clone = copy.deepcopy(original)
+        for blob_list in (clone.input, clone.output):
+            new_names = [blob_rename_f(blob) for blob in blob_list]
+            # Replace the content in place; the field itself is kept.
+            del blob_list[:]
+            blob_list.extend(new_names)
+        renamed_ops.append(clone)
+    return renamed_ops
+
+
+def _rename_blob(name, blob_sizes, blob_ranges):
+    """
+    Build a display label for blob ``name``: the name itself, followed on
+    separate lines by its size and its range (each formatted as ``[a, b, ...]``)
+    when the corresponding dict is given and contains the blob.
+    """
+
+    def _list_to_str(bsize):
+        return "[" + ", ".join(str(x) for x in bsize) + "]"
+
+    label = name
+    for annotations in (blob_sizes, blob_ranges):
+        if annotations is not None and name in annotations:
+            label += "\n" + _list_to_str(annotations[name])
+    return label
+
+
+# NOTE: graph_name cannot contain the word 'graph'.
+def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
+    """
+    Draw ``net`` to ``file_name`` through ``save_graph_base``, labelling each
+    blob with its size and range when ``blob_sizes`` / ``blob_ranges`` are given.
+    """
+    label_blob = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
+    return save_graph_base(net, file_name, graph_name, op_only, label_blob)
+
+
+def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
+    """
+    Draw the ops of ``net`` with ``net_drawer`` and write the image to ``file_name``.
+
+    Args:
+        net: the net whose ``op`` list is drawn.
+        file_name: output path; the extension selects the format
+            (".png", ".pdf" or ".svg").
+        graph_name: name handed to ``net_drawer``.
+        op_only: when true the minimal graph is drawn, otherwise the full one.
+        blob_rename_func: optional callable applied to every blob name first.
+
+    Returns:
+        The graph object built by ``net_drawer``. Failures while writing the
+        file are printed, not raised.
+    """
+    graph = None
+    ops = net.op
+    if blob_rename_func is not None:
+        ops = _modify_blob_names(ops, blob_rename_func)
+    if not op_only:
+        graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
+    else:
+        graph = net_drawer.GetPydotGraphMinimal(
+            ops, graph_name, rankdir="TB", minimal_dependency=True
+        )
+
+    try:
+        par_dir = os.path.dirname(file_name)
+        # A bare file name has an empty dirname; os.makedirs("") would raise and
+        # the image would never be written, so only create real directories.
+        if par_dir and not os.path.exists(par_dir):
+            os.makedirs(par_dir)
+
+        extension = os.path.splitext(os.path.basename(file_name))[-1]
+        if extension == ".png":
+            graph.write_png(file_name)
+        elif extension == ".pdf":
+            graph.write_pdf(file_name)
+        elif extension == ".svg":
+            graph.write_svg(file_name)
+        else:
+            print("Incorrect format {}".format(extension))
+    except Exception as e:
+        print("Error when writing graph to image {}".format(e))
+
+    return graph
+
+
+# ==== torch/utils_toffee/aten_to_caffe2.py ====================================
+
+
+def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
+    """
+    For ONNX exported model, GroupNorm will be represented as ATen op,
+    this can be a drop in replacement from ATen to GroupNorm.
+
+    Every ATen op whose "operator" argument is "group_norm" is converted in
+    place: the "operator" and "cudnn_enabled" arguments are dropped,
+    "num_groups" is renamed to "group", and the op type becomes "GroupNorm".
+    """
+    count = 0
+    for op in predict_net.op:
+        if op.type != "ATen":
+            continue
+        op_name = get_pb_arg_vals(op, "operator", None)  # return byte in py3
+        if not op_name or op_name.decode() != "group_norm":
+            continue
+
+        op.arg.remove(get_pb_arg(op, "operator"))
+
+        # Drop the argument whenever it is present, including when its value
+        # is 0 (a truthiness test would leave it on the converted op).
+        cudnn_arg = get_pb_arg(op, "cudnn_enabled")
+        if cudnn_arg is not None:
+            op.arg.remove(cudnn_arg)
+
+        num_groups = get_pb_arg_vali(op, "num_groups", None)
+        if num_groups is not None:
+            op.arg.remove(get_pb_arg(op, "num_groups"))
+            check_set_pb_arg(op, "group", "i", num_groups)
+
+        op.type = "GroupNorm"
+        count += 1
+    # Report as soon as at least one op was replaced.
+    if count > 0:
+        logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
+
+
+# ==== torch/utils_toffee/alias.py =============================================
+
+
+def alias(x, name, is_backward=False):
+    """
+    During ONNX export, wrap tensor ``x`` in an ``AliasWithName`` placeholder op
+    carrying ``name``; outside of export, return ``x`` unchanged.
+    """
+    if torch.onnx.is_in_onnx_export():
+        assert isinstance(x, torch.Tensor)
+        return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
+    return x
+
+
+def fuse_alias_placeholder(predict_net, init_net):
+    """Remove AliasWithName placeholder and rename the input/output of it"""
+    # Pass 1: give the placeholder's input and output the requested name.
+    for op_idx, op in enumerate(predict_net.op):
+        if op.type != "AliasWithName":
+            continue
+        assert len(op.input) == 1
+        assert len(op.output) == 1
+        name = get_pb_arg_vals(op, "name", None).decode()
+        is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
+        rename_op_input(predict_net, init_net, op_idx, 0, name, from_producer=is_backward)
+        rename_op_output(predict_net, op_idx, 0, name)
+
+    # Pass 2: drop the placeholders, which by now only map a blob to itself.
+    kept_ops = []
+    for op in predict_net.op:
+        if op.type == "AliasWithName":
+            # safety check
+            assert op.input == op.output
+            assert op.input[0] == op.arg[0].s.decode()
+            continue
+        kept_ops.append(op)
+    del predict_net.op[:]
+    predict_net.op.extend(kept_ops)
+
+
+# ==== torch/utils_caffe2/graph_transform.py ===================================
+
+
+class IllegalGraphTransformError(ValueError):
+ """Raised when a requested graph transform cannot be applied to the given net."""
+
+
+def _rename_versioned_blob_in_proto(
+    proto: caffe2_pb2.NetDef,
+    old_name: str,
+    new_name: str,
+    version: int,
+    ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
+    start_versions: Dict[str, int],
+    end_versions: Dict[str, int],
+):
+    """
+    Rename, in place, every occurrence of blob ``old_name`` at ``version`` in ``proto``.
+
+    Op inputs and outputs are matched through ``ssa``. ``external_input`` is
+    updated when the blob's start version equals ``version`` and
+    ``external_output`` when its end version does (a missing entry counts as 0).
+    """
+    target = (old_name, version)
+
+    def _rename_matching(names, versioned_names):
+        for pos in range(len(names)):
+            if versioned_names[pos] == target:
+                names[pos] = new_name
+
+    def _rename_external(names):
+        for pos in range(len(names)):
+            if names[pos] == old_name:
+                names[pos] = new_name
+
+    # Operator list
+    for op, op_ssa in zip(proto.op, ssa):
+        versioned_inputs, versioned_outputs = op_ssa
+        _rename_matching(op.input, versioned_inputs)
+        _rename_matching(op.output, versioned_outputs)
+    # external_input
+    if start_versions.get(old_name, 0) == version:
+        _rename_external(proto.external_input)
+    # external_output
+    if end_versions.get(old_name, 0) == version:
+        _rename_external(proto.external_output)
+
+
+def rename_op_input(
+    predict_net: caffe2_pb2.NetDef,
+    init_net: caffe2_pb2.NetDef,
+    op_id: int,
+    input_id: int,
+    new_name: str,
+    from_producer: bool = False,
+):
+    """
+    Rename the input_id-th input of the op_id-th operator in predict_net to
+    new_name, re-routing external_input and init_net when necessary.
+    - The input must be consumed by this op only.
+    - predict_net and init_net are modified in-place.
+    - With from_producer enabled the rename is applied to the output of the
+      producing op instead, so every other consumer of that blob is updated
+      too. Use with care, this may have unintended effects.
+
+    Raises:
+        NotImplementedError: from_producer is set but predict_net has no op
+            producing the input.
+        IllegalGraphTransformError: the input is consumed by more than one op.
+    """
+    assert isinstance(predict_net, caffe2_pb2.NetDef)
+    assert isinstance(init_net, caffe2_pb2.NetDef)
+
+    init_net_ssa, init_net_versions = core.get_ssa(init_net)
+    # predict_net versions continue from where init_net left off.
+    predict_net_ssa, predict_net_versions = core.get_ssa(
+        predict_net, copy.deepcopy(init_net_versions)
+    )
+
+    versioned_inputs, _ = predict_net_ssa[op_id]
+    old_name, version = versioned_inputs[input_id]
+    target = (old_name, version)
+
+    if from_producer:
+        producer_map = get_producer_map(predict_net_ssa)
+        if target not in producer_map:
+            raise NotImplementedError(
+                "Can't find producer, the input {} is probably from"
+                " init_net, this is not supported yet.".format(old_name)
+            )
+        producer_op_id, producer_output_id = producer_map[target]
+        rename_op_output(predict_net, producer_op_id, producer_output_id, new_name)
+        return
+
+    num_consumers = sum(target in op_ssa[0] for op_ssa in predict_net_ssa)
+    if num_consumers > 1:
+        message = (
+            "Input '{}' of operator(#{}) are consumed by other ops, please use"
+            + " rename_op_output on the producer instead. Offending op: \n{}"
+        )
+        raise IllegalGraphTransformError(message.format(old_name, op_id, predict_net.op[op_id]))
+
+    # update init_net
+    _rename_versioned_blob_in_proto(
+        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
+    )
+    # update predict_net
+    _rename_versioned_blob_in_proto(
+        predict_net,
+        old_name,
+        new_name,
+        version,
+        predict_net_ssa,
+        init_net_versions,
+        predict_net_versions,
+    )
+
+
+def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
+    """
+    Rename the output_id-th output of the op_id-th operator in predict_net to
+    new_name, re-routing its consumers and external_output when necessary.
+    - The output may have multiple consumers.
+    - predict_net is modified in-place; init_net is not needed.
+    """
+    assert isinstance(predict_net, caffe2_pb2.NetDef)
+
+    ssa, final_versions = core.get_ssa(predict_net)
+    _, versioned_outputs = ssa[op_id]
+    old_name, version = versioned_outputs[output_id]
+
+    # update predict_net
+    _rename_versioned_blob_in_proto(
+        predict_net, old_name, new_name, version, ssa, {}, final_versions
+    )
+
+
+def get_sub_graph_external_input_output(
+    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
+) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
+    """
+    Return the list of external input/output of sub-graph,
+    each element is tuple of the name and corresponding version in predict_net.
+
+    external input/output is defined the same way as caffe2 NetDef: an external
+    input is consumed but not produced inside the sub-graph, an external output
+    is produced inside it and consumed by an op outside of it or listed in
+    predict_net.external_output. Both lists keep first-occurrence order and
+    contain no duplicates.
+    """
+    ssa, versions = core.get_ssa(predict_net)
+    sub_graph_ops = set(sub_graph_op_indices)
+
+    all_inputs = []
+    seen_inputs = set()
+    all_outputs = []
+    for op_id in sub_graph_op_indices:
+        for inp in ssa[op_id][0]:
+            # also covers a blob repeated within a single op's inputs
+            if inp not in seen_inputs:
+                seen_inputs.add(inp)
+                all_inputs.append(inp)
+        all_outputs += list(ssa[op_id][1])  # ssa output won't repeat
+
+    # for versioned blobs, external inputs are just those blob in all_inputs
+    # but not in all_outputs
+    produced_inside = set(all_outputs)
+    ext_inputs = [inp for inp in all_inputs if inp not in produced_inside]
+
+    # external outputs are essentially outputs of this subgraph that are used
+    # outside of this sub-graph (including predict_net.external_output).
+    # The lookup set is built once instead of once per output.
+    consumed_outside = {(outp, versions[outp]) for outp in predict_net.external_output}
+    for i, op_ssa in enumerate(ssa):
+        if i not in sub_graph_ops:
+            consumed_outside.update(op_ssa[0])
+    ext_outputs = [outp for outp in all_outputs if outp in consumed_outside]
+
+    return ext_inputs, ext_outputs
+
+
+class DiGraph:
+    """Directed graph over caffe2 blobs: each vertex is a versioned blob."""
+
+    def __init__(self):
+        self.vertices = set()
+        self.graph = collections.defaultdict(list)
+
+    def add_edge(self, u, v):
+        """Add the directed edge u -> v, registering both end points as vertices."""
+        self.graph[u].append(v)
+        self.vertices.update((u, v))
+
+    # grab from https://www.geeksforgeeks.org/find-paths-given-source-destination/
+    def get_all_paths(self, s, d):
+        """Return every simple path from ``s`` to ``d``, each as a list including both ends."""
+        on_path = dict.fromkeys(self.vertices, False)
+        trail = []
+        found = []
+
+        def _walk(node):
+            on_path[node] = True
+            trail.append(node)
+            if node == d:
+                found.append(copy.deepcopy(trail))
+            else:
+                for successor in self.graph[node]:
+                    if not on_path[successor]:
+                        _walk(successor)
+            # backtrack so that the node can take part in other paths
+            trail.pop()
+            on_path[node] = False
+
+        _walk(s)
+        return found
+
+    @staticmethod
+    def from_ssa(ssa):
+        """Build the graph with one edge from every input to every output of each op."""
+        graph = DiGraph()
+        for op_ssa in ssa:
+            for inp in op_ssa[0]:
+                for outp in op_ssa[1]:
+                    graph.add_edge(inp, outp)
+        return graph
+
+
+def _get_dependency_chain(ssa, versioned_target, versioned_source):
+    """
+    Return the index list of relevant operator to produce target blob from source blob,
+    if there's no dependency, return empty list.
+
+    Args:
+        ssa: per-op (versioned inputs, versioned outputs) pairs of the net.
+        versioned_target: the (name, version) blob to reach.
+        versioned_source: the (name, version) blob to start from; it must have
+            at least one consumer in ``ssa``.
+
+    Returns:
+        Sorted op indices lying on any path from source to target.
+    """
+
+    # finding all paths between nodes can be O(N!), thus we can only search
+    # in the subgraph using the op starting from the first consumer of source blob
+    # to the producer of the target blob.
+    consumer_map = get_consumer_map(ssa)
+    producer_map = get_producer_map(ssa)
+    first_consumer = min(x[0] for x in consumer_map[versioned_source])
+    # Clamp at 0: a negative start would make the slice below count from the
+    # end of `ssa`, giving an unrelated (usually empty) window and silently
+    # dropping the dependency chain when the source is consumed within the
+    # first 15 ops.
+    start_op = max(first_consumer - 15, 0)
+    end_op = (
+        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
+    )
+    sub_graph_ssa = ssa[start_op : end_op + 1]
+    if len(sub_graph_ssa) > 30:
+        logger.warning(
+            "Subgraph bebetween {} and {} is large (from op#{} to op#{}), it"
+            " might take non-trival time to find all paths between them.".format(
+                versioned_source, versioned_target, start_op, end_op
+            )
+        )
+
+    dag = DiGraph.from_ssa(sub_graph_ssa)
+    paths = dag.get_all_paths(versioned_source, versioned_target)  # include two ends
+    # The first blob of each path is the source itself, so skip its producer.
+    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
+    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
+
+
+def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
+    """
+    Identify the reshape sub-graphs in a protobuf.
+    A reshape sub-graph matches the following pattern, where the new shape fed
+    to Reshape is itself computed from the data input:
+
+    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) --+
+        |                                                v
+        +--------------------------------------------> Reshape -> (output_blob)
+
+    Return:
+        List of sub-graphs, each sub-graph is represented as a list of indices
+        of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
+    """
+    ssa, _ = core.get_ssa(predict_net)
+
+    sub_graphs = []
+    for op_idx, op in enumerate(predict_net.op):
+        if op.type != "Reshape":
+            continue
+        assert len(op.input) == 2
+        versioned_inputs = ssa[op_idx][0]
+        data_source = versioned_inputs[0]
+        shape_source = versioned_inputs[1]
+        chain = _get_dependency_chain(ssa, shape_source, data_source)
+        sub_graphs.append(chain + [op_idx])
+    return sub_graphs
+
+
+def remove_reshape_for_fc(predict_net, params):
+ """
+ In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape
+ a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping
+ doesn't work well with ONNX and Int8 tools, and cause using extra
+ ops (eg. ExpandDims) that might not be available on mobile.
+ Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape
+ after exporting ONNX model.
+
+ Returns the pair (new predict_net, params). Note that the given predict_net
+ has its Reshape outputs renamed in place before it is deep-copied, and
+ entries are deleted from the given params dict.
+ """
+ from caffe2.python import core
+
+ # find all reshape sub-graph that can be removed, which is now all Reshape
+ # sub-graph whose output is only consumed by FC.
+ # TODO: to make it safer, we may need the actual value to better determine
+ # if a Reshape before FC is removable.
+ reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
+ sub_graphs_to_remove = []
+ for reshape_sub_graph in reshape_sub_graphs:
+ # the Reshape op is always the last index of its sub-graph
+ reshape_op_id = reshape_sub_graph[-1]
+ assert predict_net.op[reshape_op_id].type == "Reshape"
+ ssa, _ = core.get_ssa(predict_net)
+ reshape_output = ssa[reshape_op_id][1][0]
+ consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
+ if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
+ # safety check if the sub-graph is isolated, for this reshape sub-graph,
+ # it means it has one non-param external input and one external output.
+ ext_inputs, ext_outputs = get_sub_graph_external_input_output(
+ predict_net, reshape_sub_graph
+ )
+ # version 0 blobs are treated as params, anything else as real inputs
+ non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
+ if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
+ sub_graphs_to_remove.append(reshape_sub_graph)
+
+ # perform removing subgraph by:
+ # 1: rename the Reshape's output to its input, then the graph can be
+ # seen as an in-place identity, meaning its external input/output are the same.
+ # 2: simply remove those ops.
+ remove_op_ids = []
+ params_to_remove = []
+ for sub_graph in sub_graphs_to_remove:
+ logger.info(
+ "Remove Reshape sub-graph:\n{}".format(
+ "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
+ )
+ )
+ reshape_op_id = sub_graph[-1]
+ new_reshap_output = predict_net.op[reshape_op_id].input[0]
+ rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
+ # re-derive the boundary after the rename: the sub-graph must now read
+ # and write the same blob name, one version apart
+ ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
+ non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
+ params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
+ assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
+ assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
+ assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
+ remove_op_ids.extend(sub_graph)
+ params_to_remove.extend(params_ext_inputs)
+
+ # the ops are dropped from a copy; op indices stay valid because nothing
+ # was removed from the op list before this point
+ predict_net = copy.deepcopy(predict_net)
+ new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
+ del predict_net.op[:]
+ predict_net.op.extend(new_ops)
+ for versioned_params in params_to_remove:
+ name = versioned_params[0]
+ logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
+ del params[name]
+ predict_net.external_input.remove(name)
+
+ return predict_net, params
+
+
+def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
+    """
+    In-place fuse extra copy ops between cpu/gpu for the following case:
+        a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
+                        -CopyBToA> c2 -NextOp2-> d2
+    The fused network will look like:
+        a -NextOp1-> d1
+          -NextOp2-> d2
+
+    A forward copy is fused only when its output is not an external output,
+    has at least one consumer, every consumer is the reverse copy, and no
+    reverse copy produces an external output. Every consumer of a reverse
+    copy's output is rewired to read "a" directly.
+    """
+
+    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]
+
+    def _fuse_once(predict_net):
+        ssa, blob_versions = core.get_ssa(predict_net)
+        consumer_map = get_consumer_map(ssa)
+        versioned_external_output = [
+            (name, blob_versions[name]) for name in predict_net.external_output
+        ]
+
+        for op_id, op in enumerate(predict_net.op):
+            if op.type in _COPY_OPS:
+                fw_copy_versioned_output = ssa[op_id][1][0]
+                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
+                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]
+
+                is_fusable = (
+                    len(consumer_ids) > 0
+                    and fw_copy_versioned_output not in versioned_external_output
+                    and all(
+                        predict_net.op[_op_id].type == reverse_op_type
+                        and ssa[_op_id][1][0] not in versioned_external_output
+                        for _op_id in consumer_ids
+                    )
+                )
+
+                if is_fusable:
+                    for rv_copy_op_id in consumer_ids:
+                        # making each NextOp uses "a" directly and removing Copy ops.
+                        # All consumers are rewired (not just the first one), so a
+                        # reverse copy feeding several ops leaves no dangling input
+                        # and one with no consumer at all is simply dropped.
+                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
+                        for next_op_id, inp_id in consumer_map[rs_copy_versioned_output]:
+                            predict_net.op[next_op_id].input[inp_id] = op.input[0]
+                    # remove CopyOps
+                    new_ops = [
+                        kept_op
+                        for i, kept_op in enumerate(predict_net.op)
+                        if i != op_id and i not in consumer_ids
+                    ]
+                    del predict_net.op[:]
+                    predict_net.op.extend(new_ops)
+                    # op indices are stale now; the caller restarts the scan
+                    return True
+
+        return False
+
+    # _fuse_once returns False if nothing can be fused
+    while _fuse_once(predict_net):
+        pass
+
+
+def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
+    """
+    Remove, in place, every op none of whose outputs is used.
+
+    An output counts as used when it is an external output of the net or is
+    consumed by at least one op that is itself kept.
+    """
+    ssa, versions = core.get_ssa(net_def)
+    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
+    consumer_map = get_consumer_map(ssa)
+    removed_op_ids = set()
+
+    def _is_dead_end(versioned_blob):
+        # One surviving consumer is enough to keep the blob alive. Requiring
+        # all consumers to survive would drop a producer that a live op still
+        # needs as soon as any sibling consumer was removed.
+        has_live_consumer = any(
+            x[0] not in removed_op_ids for x in consumer_map[versioned_blob]
+        )
+        return not (versioned_blob in versioned_external_output or has_live_consumer)
+
+    # Walk backwards so that consumers are classified before their producers.
+    for i, ssa_i in reversed(list(enumerate(ssa))):
+        versioned_outputs = ssa_i[1]
+        if all(_is_dead_end(outp) for outp in versioned_outputs):
+            removed_op_ids.add(i)
+
+    # simply removing those deadend ops should have no effect to external_output
+    new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
+    del net_def.op[:]
+    net_def.op.extend(new_ops)
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/torchscript.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/torchscript.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce1c81e1b7abb65415055ae0d1d4b83e1ae111d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/torchscript.py
@@ -0,0 +1,132 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import os
+import torch
+
+from annotator.oneformer.detectron2.utils.file_io import PathManager
+
+from .torchscript_patch import freeze_training_mode, patch_instances
+
+__all__ = ["scripting_with_instances", "dump_torchscript_IR"]
+
+
+def scripting_with_instances(model, fields):
+    """
+    Script a model that uses :class:`Instances` with :func:`torch.jit.script`.
+
+    ``Instances`` gains its attributes dynamically in eager mode, which the
+    scripting compiler cannot handle directly. While this function runs:
+
+    1. A scriptable ``new_Instances`` class is generated whose attributes are
+       static and are exactly those declared in ``fields``.
+    2. ``new_Instances`` is registered so that the compiler uses it wherever
+       the model refers to ``Instances``.
+
+    Both steps are undone before returning, so another model can be scripted
+    afterwards with a different set of fields.
+
+    Example:
+        If ``Instances`` in the model carries ``proposal_boxes`` (a
+        :class:`Boxes`) and ``objectness_logits`` (a :class:`Tensor`) during
+        inference:
+        ::
+            fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor}
+            torchscript_model = scripting_with_instances(model, fields)
+
+    Note:
+        Only models in evaluation mode are supported.
+
+    Args:
+        model (nn.Module): the model to export by scripting.
+        fields (Dict[str, type]): name and type of every attribute that
+            ``Instances`` uses anywhere in the model, whether or not it is a
+            model input/output. Types not defined in detectron2 are not
+            supported for now.
+
+    Returns:
+        torch.jit.ScriptModule: the model in torchscript format
+    """
+    is_training = model.training
+    assert (
+        not is_training
+    ), "Currently we only support exporting models in evaluation mode to torchscript"
+
+    with freeze_training_mode(model), patch_instances(fields):
+        return torch.jit.script(model)
+
+
+# alias for old name
+export_torchscript_with_instances = scripting_with_instances
+
+
+def dump_torchscript_IR(model, dir):
+    """
+    Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph,
+    inlined graph). Useful for debugging.
+
+    Files written under ``dir``: ``model_ts_code.txt`` (pretty-printed code),
+    ``model_ts_IR.txt`` (graph), ``model_ts_IR_inlined.txt`` (inlined graph) and,
+    unless ``model`` is a ``torch.jit.ScriptFunction``, ``model.txt`` (``str(model)``).
+
+    Args:
+        model (TracedModule/ScriptModule/ScriptFunction): traced or scripted module
+        dir (str): output directory to dump files.
+    """
+    dir = os.path.expanduser(dir)
+    PathManager.mkdirs(dir)
+
+    def _get_script_mod(mod):
+        if isinstance(mod, torch.jit.TracedModule):
+            return mod._actual_script_module
+        return mod
+
+    # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code
+    with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f:
+
+        def get_code(mod):
+            # Try a few ways to get code using private attributes.
+            try:
+                # This contains more information than just `mod.code`
+                return _get_script_mod(mod)._c.code
+            except AttributeError:
+                pass
+            try:
+                return mod.code
+            except AttributeError:
+                return None
+
+        def dump_code(prefix, mod):
+            code = get_code(mod)
+            name = prefix or "root model"
+            if code is None:
+                f.write(f"Could not found code for {name} (type={mod.original_name})\n")
+                f.write("\n")
+            else:
+                f.write(f"\nCode for {name}, type={mod.original_name}:\n")
+                f.write(code)
+                f.write("\n")
+                f.write("-" * 80)
+
+            for name, m in mod.named_children():
+                dump_code(prefix + "." + name, m)
+
+        if isinstance(model, torch.jit.ScriptFunction):
+            code = get_code(model)
+            if code is None:
+                # get_code() may return None; f.write(None) would raise TypeError
+                # and abort the remaining dumps.
+                f.write("Could not find code for the given ScriptFunction\n")
+            else:
+                f.write(code)
+        else:
+            dump_code("", model)
+
+    def _get_graph(model):
+        try:
+            # Recursively dump IR of all modules
+            return _get_script_mod(model)._c.dump_to_str(True, False, False)
+        except AttributeError:
+            return model.graph.str()
+
+    with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f:
+        f.write(_get_graph(model))
+
+    # Dump IR of the entire graph (all submodules inlined)
+    with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f:
+        f.write(str(model.inlined_graph))
+
+    if not isinstance(model, torch.jit.ScriptFunction):
+        # Dump the model structure in pytorch style
+        with PathManager.open(os.path.join(dir, "model.txt"), "w") as f:
+            f.write(str(model))
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/export/torchscript_patch.py b/sd-webui-controlnet/annotator/oneformer/detectron2/export/torchscript_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..24c69b25dbec19221bcd8fc2e928a8393dd3aaf6
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/export/torchscript_patch.py
@@ -0,0 +1,406 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import os
+import sys
+import tempfile
+from contextlib import ExitStack, contextmanager
+from copy import deepcopy
+from unittest import mock
+import torch
+from torch import nn
+
+# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
+import annotator.oneformer.detectron2 # noqa F401
+from annotator.oneformer.detectron2.structures import Boxes, Instances
+from annotator.oneformer.detectron2.utils.env import _import_file
+
+_counter = 0
+
+
+def _clear_jit_cache():
+    """Empty TorchScript's caches of compiled module types and free functions."""
+    from torch.jit import _recursive, _state
+
+    # cache used for modules
+    _recursive.concrete_type_store.type_store.clear()
+    # cache used for free functions
+    _state._jit_caching_layer.clear()
+
+
+def _add_instances_conversion_methods(newInstances):
+    """
+    Attach a ``from_instances`` converter to the scripted Instances class.
+
+    Args:
+        newInstances: the generated class; it receives a ``from_instances``
+            attribute that builds one of its objects from an eager ``Instances``.
+    """
+    cls_name = newInstances.__name__
+
+    @torch.jit.unused
+    def from_instances(instances: Instances):
+        """
+        Build a scripted Instances holding deep copies of every field of ``instances``.
+        """
+        source_fields = instances.get_fields()
+        converted = newInstances(instances.image_size)
+        for field_name, field_value in source_fields.items():
+            # each field is stored on the generated class as "_<name>"
+            assert hasattr(
+                converted, f"_{field_name}"
+            ), f"No attribute named {field_name} in {cls_name}"
+            setattr(converted, field_name, deepcopy(field_value))
+        return converted
+
+    newInstances.from_instances = from_instances
+
+
+@contextmanager
+def patch_instances(fields):
+    """
+    A contextmanager, under which the Instances class in detectron2 is replaced
+    by a statically-typed scriptable class, defined by `fields`.
+    See more in `scripting_with_instances`.
+
+    Args:
+        fields: attribute names and types, forwarded to ``_gen_instance_module``.
+
+    Yields:
+        the generated class that stands in for ``Instances`` while scripting.
+    """
+
+    with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile(
+        mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False
+    ) as f:
+        # Stays None until the generated module has been imported, so that the
+        # cleanup below never touches an unbound name (which would replace the
+        # real error with an UnboundLocalError).
+        module = None
+        try:
+            # Objects that use Instances should not reuse previously-compiled
+            # results in cache, because `Instances` could be a new class each time.
+            _clear_jit_cache()
+
+            cls_name, s = _gen_instance_module(fields)
+            f.write(s)
+            f.flush()
+            f.close()
+
+            module = _import(f.name)
+            new_instances = getattr(module, cls_name)
+            _ = torch.jit.script(new_instances)
+            # let torchscript think Instances was scripted already
+            Instances.__torch_script_class__ = True
+            # let torchscript find new_instances when looking for the jit type of Instances
+            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
+
+            _add_instances_conversion_methods(new_instances)
+            yield new_instances
+        finally:
+            # Remove each marker independently: either may be missing if the
+            # setup above failed part-way through.
+            for attr in ("__torch_script_class__", "_jit_override_qualname"):
+                try:
+                    delattr(Instances, attr)
+                except AttributeError:
+                    pass
+            if module is not None:
+                sys.modules.pop(module.__name__, None)
+
+
+def _gen_instance_class(fields):
+ """
+ Generate the Python source text of a scriptable stand-in for ``Instances``.
+
+ Each call increments the module-level ``_counter`` and names the class
+ ``ScriptedInstances<counter>``. For every field the generated class stores
+ an ``Optional`` attribute ``_<name>`` and exposes a property/setter pair
+ ``<name>``. It also emits ``__len__``, ``has``, ``to``, ``__getitem__``,
+ ``cat`` and ``get_fields``.
+
+ Args:
+ fields (dict[name: type])
+
+ Returns:
+ tuple: ``(cls_name, source)`` where ``source`` is the generated pieces
+ joined with ``os.linesep``.
+ """
+
+ class _FieldType:
+ def __init__(self, name, type_):
+ assert isinstance(name, str), f"Field name must be str, got {name}"
+ self.name = name
+ self.type_ = type_
+ # fully-qualified type name, as written into the generated annotations
+ self.annotation = f"{type_.__module__}.{type_.__name__}"
+
+ fields = [_FieldType(k, v) for k, v in fields.items()]
+
+ def indent(level, s):
+ return " " * 4 * level + s
+
+ lines = []
+
+ # a fresh class name per call, derived from the module-level counter
+ global _counter
+ _counter += 1
+
+ cls_name = "ScriptedInstances{}".format(_counter)
+
+ field_names = tuple(x.name for x in fields)
+ extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields])
+ lines.append(
+ f"""
+class {cls_name}:
+ def __init__(self, image_size: Tuple[int, int], {extra_args}):
+ self.image_size = image_size
+ self._field_names = {field_names}
+"""
+ )
+
+ # one annotated `self._<name>` assignment per field, appended to __init__
+ for f in fields:
+ lines.append(
+ indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})")
+ )
+
+ # property getter (asserts the value is set) and setter for every field
+ for f in fields:
+ lines.append(
+ f"""
+ @property
+ def {f.name}(self) -> {f.annotation}:
+ # has to use a local for type refinement
+ # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
+ t = self._{f.name}
+ assert t is not None, "{f.name} is None and cannot be accessed!"
+ return t
+
+ @{f.name}.setter
+ def {f.name}(self, value: {f.annotation}) -> None:
+ self._{f.name} = value
+"""
+ )
+
+ # support method `__len__`
+ # (generated body returns len() of the first field that is not None)
+ lines.append(
+ """
+ def __len__(self) -> int:
+"""
+ )
+ for f in fields:
+ lines.append(
+ f"""
+ t = self._{f.name}
+ if t is not None:
+ return len(t)
+"""
+ )
+ lines.append(
+ """
+ raise NotImplementedError("Empty Instances does not support __len__!")
+"""
+ )
+
+ # support method `has`
+ lines.append(
+ """
+ def has(self, name: str) -> bool:
+"""
+ )
+ for f in fields:
+ lines.append(
+ f"""
+ if name == "{f.name}":
+ return self._{f.name} is not None
+"""
+ )
+ lines.append(
+ """
+ return False
+"""
+ )
+
+ # support method `to`
+ none_args = ", None" * len(fields)
+ lines.append(
+ f"""
+ def to(self, device: torch.device) -> "{cls_name}":
+ ret = {cls_name}(self.image_size{none_args})
+"""
+ )
+ for f in fields:
+ if hasattr(f.type_, "to"):
+ lines.append(
+ f"""
+ t = self._{f.name}
+ if t is not None:
+ ret._{f.name} = t.to(device)
+"""
+ )
+ else:
+ # For now, ignore fields that cannot be moved to devices.
+ # Maybe can support other tensor-like classes (e.g. __torch_function__)
+ pass
+ lines.append(
+ """
+ return ret
+"""
+ )
+
+ # support method `getitem`
+ none_args = ", None" * len(fields)
+ lines.append(
+ f"""
+ def __getitem__(self, item) -> "{cls_name}":
+ ret = {cls_name}(self.image_size{none_args})
+"""
+ )
+ for f in fields:
+ lines.append(
+ f"""
+ t = self._{f.name}
+ if t is not None:
+ ret._{f.name} = t[item]
+"""
+ )
+ lines.append(
+ """
+ return ret
+"""
+ )
+
+ # support method `cat`
+ # this version does not contain checks that all instances have same size and fields
+ none_args = ", None" * len(fields)
+ lines.append(
+ f"""
+ def cat(self, instances: List["{cls_name}"]) -> "{cls_name}":
+ ret = {cls_name}(self.image_size{none_args})
+"""
+ )
+ for f in fields:
+ lines.append(
+ f"""
+ t = self._{f.name}
+ if t is not None:
+ values: List[{f.annotation}] = [x.{f.name} for x in instances]
+ if torch.jit.isinstance(t, torch.Tensor):
+ ret._{f.name} = torch.cat(values, dim=0)
+ else:
+ ret._{f.name} = t.cat(values)
+"""
+ )
+ lines.append(
+ """
+ return ret"""
+ )
+
+ # support method `get_fields()`
+ lines.append(
+ """
+ def get_fields(self) -> Dict[str, Tensor]:
+ ret = {}
+ """
+ )
+ for f in fields:
+ if f.type_ == Boxes:
+ stmt = "t.tensor"
+ elif f.type_ == torch.Tensor:
+ stmt = "t"
+ else:
+ # NOTE(review): `stmt` is substituted after `ret[...] =` below, so this
+ # branch appears to generate `ret["x"] = assert False, ...`, which does
+ # not look like valid Python. Confirm whether that is intended.
+ stmt = f'assert False, "unsupported type {str(f.type_)}"'
+ lines.append(
+ f"""
+ t = self._{f.name}
+ if t is not None:
+ ret["{f.name}"] = {stmt}
+ """
+ )
+ lines.append(
+ """
+ return ret"""
+ )
+ return cls_name, os.linesep.join(lines)
+
+
+def _gen_instance_module(fields):
+    """
+    Build the source text of a module that defines the scripted Instances class.
+
+    Args:
+        fields: forwarded unchanged to ``_gen_instance_class``.
+
+    Returns:
+        tuple: ``(cls_name, source)``; ``source`` is a fixed import preamble
+        followed by the generated class definition.
+    """
+    cls_name, cls_def = _gen_instance_class(fields)
+    # TODO: find a more automatic way to enable import of other classes
+    preamble = """
+from copy import deepcopy
+import torch
+from torch import Tensor
+import typing
+from typing import *
+
+import annotator.oneformer.detectron2
+from annotator.oneformer.detectron2.structures import Boxes, Instances
+
+"""
+    return cls_name, preamble + cls_def
+
+
+def _import(path):
+    """
+    Import the file at ``path`` via ``_import_file`` and return the result.
+
+    The module name is this module's own name followed by the current value
+    of the module-level ``_counter``.
+    """
+    module_name = f"{sys.modules[__name__].__name__}{_counter}"
+    return _import_file(module_name, path, make_importable=True)
+
+
+@contextmanager
+def patch_builtin_len(modules=()):
+    """
+    Patch the builtin len() function of a few detectron2 modules
+    to use __len__ instead, because __len__ does not convert values to
+    integers and therefore is friendly to tracing.
+
+    Args:
+        modules (list[str]): names of extra modules to patch len(), in
+            addition to those in detectron2.
+    """
+
+    def _new_len(obj):
+        return obj.__len__()
+
+    # mock.patch() imports its target by dotted name. This file imports
+    # detectron2 as ``annotator.oneformer.detectron2``, so the targets use the
+    # same package path rather than a top-level ``detectron2``.
+    roi_heads_pkg = "annotator.oneformer.detectron2.modeling.roi_heads"
+
+    with ExitStack() as stack:
+        MODULES = [
+            roi_heads_pkg + ".fast_rcnn",
+            roi_heads_pkg + ".mask_head",
+            roi_heads_pkg + ".keypoint_head",
+        ] + list(modules)
+        ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES]
+        for m in ctxs:
+            m.side_effect = _new_len
+        yield
+
+
+def patch_nonscriptable_classes():
+    """
+    Patch a few detectron2 classes that cannot be scripted as written.
+    The patches are meant to leave eager-mode usage unaffected.
+    """
+    # __prepare_scriptable__ can also be added to models for easier maintenance.
+    # But it complicates the clean model code.
+
+    from annotator.oneformer.detectron2.modeling.backbone import ResNet, FPN
+
+    # Due to https://github.com/pytorch/pytorch/issues/36061,
+    # we change backbone to use ModuleList for scripting.
+    # (note: this changes param names in state_dict)
+
+    def prepare_resnet(self):
+        clone = deepcopy(self)
+        clone.stages = nn.ModuleList(clone.stages)
+        # drop the per-stage attributes named in stage_names
+        for stage_name in self.stage_names:
+            delattr(clone, stage_name)
+        return clone
+
+    def prepare_fpn(self):
+        clone = deepcopy(self)
+        clone.lateral_convs = nn.ModuleList(clone.lateral_convs)
+        clone.output_convs = nn.ModuleList(clone.output_convs)
+        # drop every child whose name starts with "fpn_"
+        for child_name, _ in self.named_children():
+            if not child_name.startswith("fpn_"):
+                continue
+            delattr(clone, child_name)
+        return clone
+
+    ResNet.__prepare_scriptable__ = prepare_resnet
+    FPN.__prepare_scriptable__ = prepare_fpn
+
+    # Annotate some attributes to be constants for the purpose of scripting,
+    # even though they are not constants in eager mode.
+    from annotator.oneformer.detectron2.modeling.roi_heads import StandardROIHeads
+
+    if hasattr(StandardROIHeads, "__annotations__"):
+        # copy first to avoid editing annotations of base class
+        annotations = deepcopy(StandardROIHeads.__annotations__)
+        StandardROIHeads.__annotations__ = annotations
+        for flag in ("mask_on", "keypoint_on"):
+            annotations[flag] = torch.jit.Final[bool]
+
+
+# These patches are not supposed to have side-effects.
+patch_nonscriptable_classes()
+
+
+@contextmanager
+def freeze_training_mode(model):
+    """
+    A context manager that annotates the "training" attribute of every submodule
+    to constant, so that the training codepath in these modules can be
+    meta-compiled away. Upon exiting, the annotations are reverted.
+
+    Args:
+        model (nn.Module): every class found in ``model.modules()`` is annotated,
+            except classes that have a ``__constants__`` attribute.
+    """
+    classes = {type(x) for x in model.modules()}
+    # __constants__ is the old way to annotate constants and not compatible
+    # with __annotations__ .
+    classes = {x for x in classes if not hasattr(x, "__constants__")}
+    for cls in classes:
+        cls.__annotations__["training"] = torch.jit.Final[bool]
+    try:
+        yield
+    finally:
+        # Revert even when the body raises (e.g. scripting fails); otherwise the
+        # Final[bool] annotation would stay on these classes.
+        for cls in classes:
+            cls.__annotations__["training"] = bool
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/__init__.py b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..761a3d1c7afa049e9779ee9fc4d299e9aae38cad
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .batch_norm import FrozenBatchNorm2d, get_norm, NaiveSyncBatchNorm, CycleBatchNormList
+from .deform_conv import DeformConv, ModulatedDeformConv
+from .mask_ops import paste_masks_in_image
+from .nms import batched_nms, batched_nms_rotated, nms, nms_rotated
+from .roi_align import ROIAlign, roi_align
+from .roi_align_rotated import ROIAlignRotated, roi_align_rotated
+from .shape_spec import ShapeSpec
+from .wrappers import (
+ BatchNorm2d,
+ Conv2d,
+ ConvTranspose2d,
+ cat,
+ interpolate,
+ Linear,
+ nonzero_tuple,
+ cross_entropy,
+ empty_input_loss_func_wrapper,
+ shapes_to_tensor,
+ move_device_like,
+)
+from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
+from .aspp import ASPP
+from .losses import ciou_loss, diou_loss
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/aspp.py b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/aspp.py
new file mode 100644
index 0000000000000000000000000000000000000000..14861aa9ede4fea6a69a49f189bcab997b558148
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/aspp.py
@@ -0,0 +1,144 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from copy import deepcopy
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .batch_norm import get_norm
+from .blocks import DepthwiseSeparableConv2d
+from .wrappers import Conv2d
+
+
+class ASPP(nn.Module):
+    """
+    Atrous Spatial Pyramid Pooling (ASPP).
+
+    Five parallel branches (a 1x1 conv, three dilated 3x3 convs and an image
+    pooling branch) are concatenated along channels and projected back to
+    ``out_channels`` by a 1x1 conv.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        dilations,
+        *,
+        norm,
+        activation,
+        pool_kernel_size=None,
+        dropout: float = 0.0,
+        use_depthwise_separable_conv=False,
+    ):
+        """
+        Args:
+            in_channels (int): number of input channels for ASPP.
+            out_channels (int): number of output channels.
+            dilations (list): a list of 3 dilations in ASPP.
+            norm (str or callable): normalization for all conv layers.
+                See :func:`layers.get_norm` for supported format. It is not
+                applied to the conv that follows the pooling layer.
+            activation (callable): activation function; every layer receives
+                its own deep copy.
+            pool_kernel_size (tuple, list): the average pooling size (kh, kw)
+                of the image pooling branch. ``None`` selects global average
+                pooling. Otherwise forward() requires the input height and
+                width to be multiples of (kh, kw). It is recommended to train
+                with a fixed feature size and set this option to that size, so
+                that pooling is global in training and the pooling window
+                stays consistent in inference.
+            dropout (float): apply dropout on the output of ASPP. It is used in
+                the official DeepLab implementation with a rate of 0.1:
+                https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532  # noqa
+            use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d
+                for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`.
+        """
+        super().__init__()
+        assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations))
+        self.pool_kernel_size = pool_kernel_size
+        self.dropout = dropout
+        # convs carry a bias only when no norm layer follows them
+        use_bias = norm == ""
+        self.convs = nn.ModuleList()
+
+        # branch 1: conv 1x1
+        pointwise = Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            bias=use_bias,
+            norm=get_norm(norm, out_channels),
+            activation=deepcopy(activation),
+        )
+        self.convs.append(pointwise)
+        weight_init.c2_xavier_fill(pointwise)
+
+        # branches 2-4: atrous convs
+        for rate in dilations:
+            if not use_depthwise_separable_conv:
+                atrous = Conv2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    padding=rate,
+                    dilation=rate,
+                    bias=use_bias,
+                    norm=get_norm(norm, out_channels),
+                    activation=deepcopy(activation),
+                )
+                self.convs.append(atrous)
+                weight_init.c2_xavier_fill(atrous)
+                continue
+            self.convs.append(
+                DepthwiseSeparableConv2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    padding=rate,
+                    dilation=rate,
+                    norm1=norm,
+                    activation1=deepcopy(activation),
+                    norm2=norm,
+                    activation2=deepcopy(activation),
+                )
+            )
+
+        # branch 5: image pooling
+        # We do not add BatchNorm because the spatial resolution is 1x1,
+        # the original TF implementation has BatchNorm.
+        if pool_kernel_size is None:
+            pooling = nn.AdaptiveAvgPool2d(1)
+        else:
+            pooling = nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1)
+        pooled_conv = Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation))
+        image_pooling = nn.Sequential(pooling, pooled_conv)
+        weight_init.c2_xavier_fill(pooled_conv)
+        self.convs.append(image_pooling)
+
+        # fuse the five concatenated branch outputs
+        self.project = Conv2d(
+            5 * out_channels,
+            out_channels,
+            kernel_size=1,
+            bias=use_bias,
+            norm=get_norm(norm, out_channels),
+            activation=deepcopy(activation),
+        )
+        weight_init.c2_xavier_fill(self.project)
+
+    def forward(self, x):
+        """
+        Run every branch on ``x``, concatenate along dim 1 and project.
+
+        Raises:
+            ValueError: if ``pool_kernel_size`` is set and the last two dims of
+                ``x`` are not multiples of it.
+        """
+        size = x.shape[-2:]
+        if self.pool_kernel_size is not None:
+            if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]:
+                raise ValueError(
+                    "`pool_kernel_size` must be divisible by the shape of inputs. "
+                    "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size)
+                )
+        branch_outputs = []
+        for branch in self.convs:
+            branch_outputs.append(branch(x))
+        # the pooling branch (last) is resized back to the input resolution
+        branch_outputs[-1] = F.interpolate(
+            branch_outputs[-1], size=size, mode="bilinear", align_corners=False
+        )
+        fused = self.project(torch.cat(branch_outputs, dim=1))
+        if self.dropout > 0:
+            fused = F.dropout(fused, self.dropout, training=self.training)
+        return fused
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/batch_norm.py b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/batch_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..32a1e05470065e75b6caad18d36211d27af8eec0
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/batch_norm.py
@@ -0,0 +1,300 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import torch
+import torch.distributed as dist
+from fvcore.nn.distributed import differentiable_all_reduce
+from torch import nn
+from torch.nn import functional as F
+
+from annotator.oneformer.detectron2.utils import comm, env
+
+from .wrappers import BatchNorm2d
+
+
+class FrozenBatchNorm2d(nn.Module):
+    """
+    BatchNorm2d whose batch statistics and affine parameters are fixed.
+
+    "weight", "bias", "running_mean" and "running_var" are registered as
+    buffers (not parameters) and initialized so that the layer is an identity
+    transformation.
+
+    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+    which are computed from the original four parameters of BN.
+    The affine transform `x * weight + bias` will perform the equivalent
+    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+    When loading a backbone model from Caffe2, "running_mean" and "running_var"
+    will be left unchanged as identity transformation.
+
+    Other pre-trained backbone models may contain all 4 parameters.
+
+    The forward is implemented by `F.batch_norm(..., training=False)`.
+    """
+
+    _version = 3
+
+    def __init__(self, num_features, eps=1e-5):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        # 1 - eps, so that running_var + eps == 1 and the layer starts as identity
+        self.register_buffer("running_var", torch.ones(num_features) - eps)
+
+    def forward(self, x):
+        if not x.requires_grad:
+            # When gradients are not needed, F.batch_norm is a single fused op
+            # and provide more optimization opportunities.
+            return F.batch_norm(
+                x,
+                self.running_mean,
+                self.running_var,
+                self.weight,
+                self.bias,
+                training=False,
+                eps=self.eps,
+            )
+        # When gradients are needed, F.batch_norm will use extra memory
+        # because its backward op computes gradients for weight/bias as well.
+        scale = self.weight * (self.running_var + self.eps).rsqrt()
+        bias = self.bias - self.running_mean * scale
+        scale = scale.reshape(1, -1, 1, 1)
+        bias = bias.reshape(1, -1, 1, 1)
+        out_dtype = x.dtype  # may be half
+        return x * scale.to(out_dtype) + bias.to(out_dtype)
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            # No running_mean/var in early versions
+            # Filling them in silences the missing-key warnings
+            mean_key = prefix + "running_mean"
+            var_key = prefix + "running_var"
+            if mean_key not in state_dict:
+                state_dict[mean_key] = torch.zeros_like(self.running_mean)
+            if var_key not in state_dict:
+                state_dict[var_key] = torch.ones_like(self.running_var)
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def __repr__(self):
+        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
+
+    @classmethod
+    def convert_frozen_batchnorm(cls, module):
+        """
+        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+
+        Args:
+            module (torch.nn.Module):
+
+        Returns:
+            If module is BatchNorm/SyncBatchNorm, returns a new module.
+            Otherwise, in-place convert module and return it.
+
+        Similar to convert_sync_batchnorm in
+        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+        """
+        batchnorm = nn.modules.batchnorm
+        bn_types = (batchnorm.BatchNorm2d, batchnorm.SyncBatchNorm)
+
+        if not isinstance(module, bn_types):
+            # not a BN layer itself: convert its children and swap in the changed ones
+            for child_name, child in module.named_children():
+                converted = cls.convert_frozen_batchnorm(child)
+                if converted is not child:
+                    module.add_module(child_name, converted)
+            return module
+
+        frozen = cls(module.num_features)
+        if module.affine:
+            frozen.weight.data = module.weight.data.clone().detach()
+            frozen.bias.data = module.bias.data.clone().detach()
+        frozen.running_mean.data = module.running_mean.data
+        frozen.running_var.data = module.running_var.data
+        frozen.eps = module.eps
+        return frozen
+
+
+def get_norm(norm, out_channels):
+    """
+    Build a normalization layer for ``out_channels`` channels.
+
+    Args:
+        norm (str or callable): one of the names in the table below (e.g. BN,
+            SyncBN, FrozenBN, GN, LN), or a callable that takes a channel
+            number and returns the normalization layer as a nn.Module.
+            ``None`` and the empty string mean "no normalization".
+        out_channels (int): channel count passed to the layer factory.
+
+    Returns:
+        nn.Module or None: the normalization layer
+    """
+    no_norm = norm is None or (isinstance(norm, str) and len(norm) == 0)
+    if no_norm:
+        return None
+    if isinstance(norm, str):
+        factories = {
+            "BN": BatchNorm2d,
+            # Fixed in https://github.com/pytorch/pytorch/pull/36382
+            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
+            "FrozenBN": FrozenBatchNorm2d,
+            "GN": lambda channels: nn.GroupNorm(32, channels),
+            # for debugging:
+            "nnSyncBN": nn.SyncBatchNorm,
+            "naiveSyncBN": NaiveSyncBatchNorm,
+            # expose stats_mode N as an option to caller, required for zero-len inputs
+            "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"),
+            "LN": lambda channels: LayerNorm(channels),
+        }
+        # an unknown name raises KeyError here
+        norm = factories[norm]
+    return norm(out_channels)
+
+
+class NaiveSyncBatchNorm(BatchNorm2d):
+    """
+    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
+    when the batch size on each worker is different.
+    (e.g., when scale augmentation is used, or when it is applied to mask head).
+
+    This is a slower but correct alternative to `nn.SyncBatchNorm`.
+
+    Note:
+        There isn't a single definition of Sync BatchNorm.
+
+        When ``stats_mode==""``, per-worker statistics are averaged with equal
+        weight. That equals the statistics of all samples (as if they were on
+        one worker) only when all workers have the same (N, H, W). This mode
+        does not support inputs with zero batch size.
+
+        When ``stats_mode=="N"``, per-worker statistics are weighted by each
+        worker's ``N``. That equals the statistics of all samples only when
+        all workers have the same (H, W). It is slower than ``stats_mode==""``.
+
+        Even though the result of this module may not be the true statistics of all samples,
+        it may still be reasonable because it might be preferable to assign equal weights
+        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
+        on larger images. From preliminary experiments, little difference is found between such
+        a simplified implementation and an accurate computation of overall mean & variance.
+    """
+
+    def __init__(self, *args, stats_mode="", **kwargs):
+        super().__init__(*args, **kwargs)
+        assert stats_mode in ["", "N"]
+        self._stats_mode = stats_mode
+
+    def forward(self, input):
+        # nothing to synchronize with a single worker or outside training
+        if comm.get_world_size() == 1 or not self.training:
+            return super().forward(input)
+
+        batch_size, num_channels = input.shape[0], input.shape[1]
+
+        is_fp16 = input.dtype == torch.float16
+        if is_fp16:
+            # fp16 does not have good enough numerics for the reduction here
+            input = input.float()
+        mean = torch.mean(input, dim=[0, 2, 3])
+        meansqr = torch.mean(input * input, dim=[0, 2, 3])
+
+        if self._stats_mode == "":
+            assert batch_size > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
+            stats = torch.cat([mean, meansqr], dim=0)
+            stats = differentiable_all_reduce(stats) * (1.0 / dist.get_world_size())
+            mean, meansqr = torch.split(stats, num_channels)
+            momentum = self.momentum
+        else:
+            if batch_size == 0:
+                stats = torch.zeros([2 * num_channels + 1], device=mean.device, dtype=mean.dtype)
+                stats = stats + input.sum()  # make sure there is gradient w.r.t input
+            else:
+                # trailing 1 becomes this worker's batch size after the `* batch_size` below
+                stats = torch.cat(
+                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
+                )
+            stats = differentiable_all_reduce(stats * batch_size)
+
+            total_batch = stats[-1].detach()
+            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
+            mean, meansqr, _ = torch.split(
+                stats / total_batch.clamp(min=1), num_channels
+            )  # avoid div-by-zero
+
+        var = meansqr - mean * mean
+        invstd = torch.rsqrt(var + self.eps)
+        scale = self.weight * invstd
+        bias = self.bias - mean * scale
+        scale = scale.reshape(1, -1, 1, 1)
+        bias = bias.reshape(1, -1, 1, 1)
+
+        self.running_mean += momentum * (mean.detach() - self.running_mean)
+        self.running_var += momentum * (var.detach() - self.running_var)
+        normalized = input * scale + bias
+        if is_fp16:
+            normalized = normalized.half()
+        return normalized
+
+
+class CycleBatchNormList(nn.ModuleList):
+    """
+    Implement domain-specific BatchNorm by cycling.
+
+    When a BatchNorm layer is used for multiple input domains or input
+    features, it might need to maintain a separate test-time statistics
+    for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`.
+
+    This module holds N separate BN layers and advances to the next one on
+    every forward() call, wrapping around after the last.
+
+    NOTE: The caller of this module MUST guarantee to always call
+    this module by multiple of N times. Otherwise its test-time statistics
+    will be incorrect.
+    """
+
+    def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs):
+        """
+        Args:
+            length: number of BatchNorm layers to cycle.
+            bn_class: the BatchNorm class to use
+            kwargs: arguments of the BatchNorm class, such as num_features.
+                ``affine`` (default True) is consumed here: the BN layers are
+                always built with ``affine=False``.
+        """
+        self._affine = kwargs.pop("affine", True)
+        bn_layers = [bn_class(**kwargs, affine=False) for _ in range(length)]
+        super().__init__(bn_layers)
+        if self._affine:
+            # shared affine, domain-specific BN
+            num_channels = self[0].num_features
+            self.weight = nn.Parameter(torch.ones(num_channels))
+            self.bias = nn.Parameter(torch.zeros(num_channels))
+        # index of the BN layer used by the next forward()
+        self._pos = 0
+
+    def forward(self, x):
+        normalized = self[self._pos](x)
+        self._pos = (self._pos + 1) % len(self)
+
+        if not self._affine:
+            return normalized
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        return normalized * w + b
+
+    def extra_repr(self):
+        return f"affine={self._affine}"
+
+
+class LayerNorm(nn.Module):
+    """
+    A LayerNorm variant, popularized by Transformers, that normalizes mean and
+    variance point-wise over the channel dimension (dim 1) for inputs of shape
+    (batch_size, channels, height, width).
+    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa B950
+    """
+
+    def __init__(self, normalized_shape, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        # statistics over dim 1, one pair per (batch, h, w) location
+        mean = x.mean(1, keepdim=True)
+        variance = (x - mean).pow(2).mean(1, keepdim=True)
+        normalized = (x - mean) / torch.sqrt(variance + self.eps)
+        # per-channel affine, broadcast over the two trailing dims
+        return self.weight[:, None, None] * normalized + self.bias[:, None, None]
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/blocks.py b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1995a4bf7339e8deb7eaaffda4f819dda55e7ac7
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/blocks.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+
+from .batch_norm import FrozenBatchNorm2d, get_norm
+from .wrappers import Conv2d
+
+
+"""
+CNN building blocks.
+"""
+
+
+class CNNBlockBase(nn.Module):
+ """
+ A CNN block is assumed to have input channels, output channels and a stride.
+ The input and output of `forward()` method must be NCHW tensors.
+ The method can perform arbitrary computation but must match the given
+ channels and stride specification.
+
+ Attribute:
+ in_channels (int): the value passed to `__init__`, stored unchanged.
+ out_channels (int): the value passed to `__init__`, stored unchanged.
+ stride (int): the value passed to `__init__`, stored unchanged.
+ """
+
+ def __init__(self, in_channels, out_channels, stride):
+ """
+ The `__init__` method of any subclass should also contain these arguments.
+
+ Args:
+ in_channels (int): saved as `self.in_channels`.
+ out_channels (int): saved as `self.out_channels`.
+ stride (int): saved as `self.stride`.
+ """
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.stride = stride
+
+ def freeze(self):
+ """
+ Make this block not trainable.
+ This method sets all parameters to `requires_grad=False`,
+ and converts all BatchNorm layers to FrozenBatchNorm.
+
+ Returns:
+ the block itself
+ """
+ for p in self.parameters():
+ p.requires_grad = False
+ FrozenBatchNorm2d.convert_frozen_batchnorm(self)  # return value is ignored; presumably converts child modules in place - TODO confirm
+ return self
+
+
+class DepthwiseSeparableConv2d(nn.Module):
+ """
+ A kxk depthwise convolution + a 1x1 convolution.
+
+ In :paper:`xception`, norm & activation are applied on the second conv.
+ :paper:`mobilenet` uses norm & activation on both convs.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ dilation=1,
+ *,
+ norm1=None,
+ activation1=None,
+ norm2=None,
+ activation2=None,
+ ):
+ """
+ Args:
+ norm1, norm2 (str or callable): normalization for the two conv layers; a conv gets a bias only when its norm is falsy.
+ activation1, activation2 (callable(Tensor) -> Tensor): activation
+ function for the two conv layers. kernel_size, padding and dilation apply to the depthwise conv only.
+ """
+ super().__init__()
+ self.depthwise = Conv2d(
+ in_channels,
+ in_channels,  # depthwise conv keeps the channel count
+ kernel_size=kernel_size,
+ padding=padding,
+ dilation=dilation,
+ groups=in_channels,  # one group per input channel
+ bias=not norm1,
+ norm=get_norm(norm1, in_channels),
+ activation=activation1,
+ )
+ self.pointwise = Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,  # 1x1 conv changes the channel count to out_channels
+ bias=not norm2,
+ norm=get_norm(norm2, out_channels),
+ activation=activation2,
+ )
+
+ # default initialization
+ weight_init.c2_msra_fill(self.depthwise)
+ weight_init.c2_msra_fill(self.pointwise)
+
+ def forward(self, x):
+ return self.pointwise(self.depthwise(x))  # depthwise first, then pointwise
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/README.md b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..778ed3da0bae89820831bcd8a72ff7b9cad8d4dd
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/README.md
@@ -0,0 +1,7 @@
+
+
+To add a new Op:
+
+1. Create a new directory
+2. Implement new ops there
+3. Declare its Python interface in `vision.cpp`.
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
new file mode 100644
index 0000000000000000000000000000000000000000..03f4211003f42f601f0cfcf4a690f5da4a0a1f67
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h
@@ -0,0 +1,115 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#pragma once
+#include
+
+namespace detectron2 {
+
+at::Tensor ROIAlignRotated_forward_cpu(  // CPU forward pass; called by ROIAlignRotated_forward for non-CUDA input
+ const at::Tensor& input,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio);
+
+at::Tensor ROIAlignRotated_backward_cpu(  // CPU backward pass; called by ROIAlignRotated_backward for non-CUDA grad
+ const at::Tensor& grad,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int batch_size,  // batch_size, channels, height, width: passed as plain sizes instead of an input tensor
+ const int channels,
+ const int height,
+ const int width,
+ const int sampling_ratio);
+
+#if defined(WITH_CUDA) || defined(WITH_HIP)  // GPU entry points are only declared in GPU builds
+at::Tensor ROIAlignRotated_forward_cuda(  // same parameter list as ROIAlignRotated_forward_cpu
+ const at::Tensor& input,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio);
+
+at::Tensor ROIAlignRotated_backward_cuda(  // same parameter list as ROIAlignRotated_backward_cpu
+ const at::Tensor& grad,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int batch_size,
+ const int channels,
+ const int height,
+ const int width,
+ const int sampling_ratio);
+#endif  // WITH_CUDA || WITH_HIP
+
+// Interface for Python: forwards to the CUDA implementation when `input` is a CUDA tensor, otherwise to the CPU one.
+inline at::Tensor ROIAlignRotated_forward(
+ const at::Tensor& input,
+ const at::Tensor& rois,
+ const double spatial_scale,  // implicitly narrowed to the float parameter of the backend function
+ const int64_t pooled_height,  // the int64_t arguments are implicitly narrowed to the backends' int parameters
+ const int64_t pooled_width,
+ const int64_t sampling_ratio) {
+ if (input.is_cuda()) {
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+ return ROIAlignRotated_forward_cuda(
+ input,
+ rois,
+ spatial_scale,
+ pooled_height,
+ pooled_width,
+ sampling_ratio);
+#else  // built without GPU support: a CUDA input is reported as an error
+ AT_ERROR("Detectron2 is not compiled with GPU support!");
+#endif
+ }
+ return ROIAlignRotated_forward_cpu(
+ input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
+}
+
+inline at::Tensor ROIAlignRotated_backward(  // forwards to the CUDA backward when `grad` is a CUDA tensor, otherwise to the CPU one
+ const at::Tensor& grad,
+ const at::Tensor& rois,
+ const double spatial_scale,  // implicitly narrowed to the float parameter of the backend function
+ const int64_t pooled_height,  // the int64_t arguments are implicitly narrowed to the backends' int parameters
+ const int64_t pooled_width,
+ const int64_t batch_size,
+ const int64_t channels,
+ const int64_t height,
+ const int64_t width,
+ const int64_t sampling_ratio) {
+ if (grad.is_cuda()) {
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+ return ROIAlignRotated_backward_cuda(
+ grad,
+ rois,
+ spatial_scale,
+ pooled_height,
+ pooled_width,
+ batch_size,
+ channels,
+ height,
+ width,
+ sampling_ratio);
+#else  // built without GPU support: a CUDA grad is reported as an error
+ AT_ERROR("Detectron2 is not compiled with GPU support!");
+#endif
+ }
+ return ROIAlignRotated_backward_cpu(
+ grad,
+ rois,
+ spatial_scale,
+ pooled_height,
+ pooled_width,
+ batch_size,
+ channels,
+ height,
+ width,
+ sampling_ratio);
+}
+
+} // namespace detectron2
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2a3d3056cc71a4acaafb570739a9dd247a7eb1ed
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp
@@ -0,0 +1,522 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#include
+#include "ROIAlignRotated.h"
+
+// Note: this implementation originates from the Caffe2 ROIAlignRotated Op
+// and PyTorch ROIAlign (non-rotated) Op implementations.
+// The key difference between this implementation and those ones is
+// we don't do "legacy offset" in this version, as there aren't many previous
+// works, if any, using the "legacy" ROIAlignRotated Op.
+// This would make the interface a bit cleaner.
+
+namespace detectron2 {
+
+namespace {
+template
+struct PreCalc {  // one precomputed bilinear sample: four flat indices into an H*W plane and their weights
+ int pos1;  // y_low * width + x_low
+ int pos2;  // y_low * width + x_high
+ int pos3;  // y_high * width + x_low
+ int pos4;  // y_high * width + x_high
+ T w1;  // hy * hx, weight for pos1
+ T w2;  // hy * lx, weight for pos2
+ T w3;  // ly * hx, weight for pos3
+ T w4;  // ly * lx, weight for pos4
+};
+
+template
+void pre_calc_for_bilinear_interpolate(  // fills pre_calc with one PreCalc entry per (ph, pw, iy, ix) sample of a single ROI
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int iy_upper,  // loop bound for iy (samples per bin along h)
+ const int ix_upper,  // loop bound for ix (samples per bin along w)
+ T roi_start_h,
+ T roi_start_w,
+ T bin_size_h,
+ T bin_size_w,
+ int roi_bin_grid_h,
+ int roi_bin_grid_w,
+ T roi_center_h,
+ T roi_center_w,
+ T cos_theta,
+ T sin_theta,
+ std::vector>& pre_calc) {
+ int pre_calc_index = 0;  // advances by one per sample; pre_calc must already be sized for all samples
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ for (int iy = 0; iy < iy_upper; iy++) {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < ix_upper; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ // In image space, (y, x) is the order for Right Handed System,
+ // and this is essentially multiplying the point by a rotation matrix
+ // to rotate it counterclockwise through angle theta.
+ T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+ T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+ // deal with: inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty: store an all-zero entry (index 0, weight 0) so the sample adds nothing
+ PreCalc pc;
+ pc.pos1 = 0;
+ pc.pos2 = 0;
+ pc.pos3 = 0;
+ pc.pos4 = 0;
+ pc.w1 = 0;
+ pc.w2 = 0;
+ pc.w3 = 0;
+ pc.w4 = 0;
+ pre_calc[pre_calc_index] = pc;
+ pre_calc_index += 1;
+ continue;
+ }
+ // coordinates in [-1, 0) are clamped to 0
+ if (y < 0) {
+ y = 0;
+ }
+ if (x < 0) {
+ x = 0;
+ }
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+ // at or past the last row/column: use the last one for both neighbours and snap the coordinate to it
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;  // fractional offsets from the low neighbour
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ // save weights and indices
+ PreCalc pc;
+ pc.pos1 = y_low * width + x_low;
+ pc.pos2 = y_low * width + x_high;
+ pc.pos3 = y_high * width + x_low;
+ pc.pos4 = y_high * width + x_high;
+ pc.w1 = w1;
+ pc.w2 = w2;
+ pc.w3 = w3;
+ pc.w4 = w4;
+ pre_calc[pre_calc_index] = pc;
+
+ pre_calc_index += 1;
+ }
+ }
+ }
+ }
+}
+
+template
+void bilinear_interpolate_gradient(  // computes the four bilinear weights and neighbour indices for the point (y, x)
+ const int height,
+ const int width,
+ T y,
+ T x,
+ T& w1,  // outputs: weights for (y_low,x_low), (y_low,x_high), (y_high,x_low), (y_high,x_high)
+ T& w2,
+ T& w3,
+ T& w4,
+ int& x_low,  // outputs: neighbour indices, all set to -1 when the point is out of range
+ int& x_high,
+ int& y_low,
+ int& y_high) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty: zero weights and -1 indices
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+ // coordinates in [-1, 0) are clamped to 0
+ if (y < 0) {
+ y = 0;
+ }
+
+ if (x < 0) {
+ x = 0;
+ }
+
+ y_low = (int)y;
+ x_low = (int)x;
+ // at or past the last row/column: use the last one for both neighbours and snap the coordinate to it
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;  // fractional offsets from the low neighbour
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+
+template
+inline void add(T* address, const T& val) {  // plain (non-atomic) accumulation into *address
+ *address += val;
+}
+
+} // namespace
+
+template
+void ROIAlignRotatedForward(  // CPU forward: averages bilinear samples of `input` inside each rotated ROI bin
+ const int nthreads,  // total number of output elements; only used to derive n_rois
+ const T* input,
+ const T& spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,  // samples per bin along each axis when > 0; otherwise derived from the ROI size
+ const T* rois,  // read as 6 consecutive values per ROI
+ T* output) {
+ int n_rois = nthreads / channels / pooled_width / pooled_height;
+ // (n, c, ph, pw) is an element in the pooled output
+ // can be parallelized using omp
+ // #pragma omp parallel for num_threads(32)
+ for (int n = 0; n < n_rois; n++) {
+ int index_n = n * channels * pooled_width * pooled_height;
+ // ROI layout as read below: [batch index, center w, center h, width, height, angle in degrees]
+ const T* current_roi = rois + n * 6;
+ int roi_batch_ind = current_roi[0];
+
+ // Do not use rounding; this implementation detail is critical
+ // ROIAlignRotated supports align == true, i.e., continuous coordinate
+ // by default, thus the 0.5 offset
+ T offset = (T)0.5;
+ T roi_center_w = current_roi[1] * spatial_scale - offset;
+ T roi_center_h = current_roi[2] * spatial_scale - offset;
+ T roi_width = current_roi[3] * spatial_scale;
+ T roi_height = current_roi[4] * spatial_scale;
+ T theta = current_roi[5] * M_PI / 180.0;  // degrees to radians
+ T cos_theta = cos(theta);
+ T sin_theta = sin(theta);
+
+ AT_ASSERTM(
+ roi_width >= 0 && roi_height >= 0,
+ "ROIs in ROIAlignRotated do not have non-negative size!");
+
+ T bin_size_h = static_cast(roi_height) / static_cast(pooled_height);
+ T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin; max(..., 1) avoids a zero divisor
+ const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+ // we want to precalculate indices and weights shared by all channels,
+ // this is the key point of optimization
+ std::vector> pre_calc(
+ roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ T roi_start_h = -roi_height / 2.0;
+ T roi_start_w = -roi_width / 2.0;
+
+ pre_calc_for_bilinear_interpolate(
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ roi_bin_grid_h,  // passed as iy_upper
+ roi_bin_grid_w,  // passed as ix_upper
+ roi_start_h,
+ roi_start_w,
+ bin_size_h,
+ bin_size_w,
+ roi_bin_grid_h,
+ roi_bin_grid_w,
+ roi_center_h,
+ roi_center_w,
+ cos_theta,
+ sin_theta,
+ pre_calc);
+
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * pooled_width * pooled_height;
+ const T* offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
+ int pre_calc_index = 0;  // restarted per channel: the same precomputed samples are reused
+
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ int index = index_n_c + ph * pooled_width + pw;
+
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ PreCalc pc = pre_calc[pre_calc_index];
+ output_val += pc.w1 * offset_input[pc.pos1] +
+ pc.w2 * offset_input[pc.pos2] +
+ pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
+
+ pre_calc_index += 1;
+ }
+ }
+ output_val /= count;
+
+ output[index] = output_val;
+ } // for pw
+ } // for ph
+ } // for c
+ } // for n
+}
+
+template
+void ROIAlignRotatedBackward(  // CPU backward: scatters each output gradient onto the input cells its samples read
+ const int nthreads,  // number of output-gradient elements to process
+ // may not be contiguous. should index using n_stride, etc
+ const T* grad_output,
+ const T& spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ T* grad_input,  // accumulated into with +=
+ const T* rois,  // read as 6 consecutive values per ROI
+ const int n_stride,
+ const int c_stride,
+ const int h_stride,
+ const int w_stride) {
+ for (int index = 0; index < nthreads; index++) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ // ROI layout as read below: [batch index, center w, center h, width, height, angle in degrees]
+ const T* current_roi = rois + n * 6;
+ int roi_batch_ind = current_roi[0];
+
+ // Do not use rounding; this implementation detail is critical
+ // ROIAlignRotated supports align == true, i.e., continuous coordinate
+ // by default, thus the 0.5 offset
+ T offset = (T)0.5;
+ T roi_center_w = current_roi[1] * spatial_scale - offset;
+ T roi_center_h = current_roi[2] * spatial_scale - offset;
+ T roi_width = current_roi[3] * spatial_scale;
+ T roi_height = current_roi[4] * spatial_scale;
+ T theta = current_roi[5] * M_PI / 180.0;  // degrees to radians
+ T cos_theta = cos(theta);
+ T sin_theta = sin(theta);
+
+ AT_ASSERTM(
+ roi_width >= 0 && roi_height >= 0,
+ "ROIs in ROIAlignRotated do not have non-negative size!");
+
+ T bin_size_h = static_cast(roi_height) / static_cast(pooled_height);
+ T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);
+
+ T* offset_grad_input =
+ grad_input + ((roi_batch_ind * channels + c) * height * width);
+ // grad_output is addressed through the strides, not through `index`
+ int output_offset = n * n_stride + c * c_stride;
+ const T* offset_grad_output = grad_output + output_offset;
+ const T grad_output_this_bin =
+ offset_grad_output[ph * h_stride + pw * w_stride];
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ T roi_start_h = -roi_height / 2.0;
+ T roi_start_w = -roi_width / 2.0;
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4; only used inside the loops below, which do not run when it is 0
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+ T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+
+ bilinear_interpolate_gradient(
+ height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);
+
+ T g1 = grad_output_this_bin * w1 / count;
+ T g2 = grad_output_this_bin * w2 / count;
+ T g3 = grad_output_this_bin * w3 / count;
+ T g4 = grad_output_this_bin * w4 / count;
+ // indices are -1 when the sample fell outside the feature map; such samples are skipped
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ // atomic add is not needed for now since it is single threaded
+ add(offset_grad_input + y_low * width + x_low, static_cast(g1));
+ add(offset_grad_input + y_low * width + x_high, static_cast(g2));
+ add(offset_grad_input + y_high * width + x_low, static_cast(g3));
+ add(offset_grad_input + y_high * width + x_high, static_cast(g4));
+ } // if
+ } // ix
+ } // iy
+ } // for
+} // ROIAlignRotatedBackward
+
+at::Tensor ROIAlignRotated_forward_cpu(  // host-side CPU forward: validates arguments, allocates the output, dispatches on dtype
+ const at::Tensor& input,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio) {
+ AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
+ AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+ at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+ // label used in type-check error messages; it names this function
+ at::CheckedFrom c = "ROIAlignRotated_forward_cpu";
+ at::checkAllSameType(c, {input_t, rois_t});
+
+ auto num_rois = rois.size(0);
+ auto channels = input.size(1);
+ auto height = input.size(2);
+ auto width = input.size(3);
+ // output is zero-initialized with shape {num_rois, channels, pooled_height, pooled_width}
+ at::Tensor output = at::zeros(
+ {num_rois, channels, pooled_height, pooled_width}, input.options());
+
+ auto output_size = num_rois * pooled_height * pooled_width * channels;
+ // nothing to compute for an empty output
+ if (output.numel() == 0) {
+ return output;
+ }
+ // the kernel indexes raw pointers, so it is given contiguous copies/views
+ auto input_ = input.contiguous(), rois_ = rois.contiguous();
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ input.scalar_type(), "ROIAlignRotated_forward", [&] {
+ ROIAlignRotatedForward(
+ output_size,
+ input_.data_ptr(),
+ spatial_scale,
+ channels,
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ sampling_ratio,
+ rois_.data_ptr(),
+ output.data_ptr());
+ });
+ return output;
+}
+
+at::Tensor ROIAlignRotated_backward_cpu(  // host-side CPU backward: validates arguments, allocates grad_input, dispatches on dtype
+ const at::Tensor& grad,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int batch_size,
+ const int channels,
+ const int height,
+ const int width,
+ const int sampling_ratio) {
+ AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
+ AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
+
+ at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+
+ at::CheckedFrom c = "ROIAlignRotated_backward_cpu";
+ at::checkAllSameType(c, {grad_t, rois_t});
+ // zero-initialized because the kernel accumulates into it
+ at::Tensor grad_input =
+ at::zeros({batch_size, channels, height, width}, grad.options());
+
+ // handle possibly empty gradients
+ if (grad.numel() == 0) {
+ return grad_input;
+ }
+
+ // get stride values to ensure indexing into gradients is correct.
+ int n_stride = grad.stride(0);
+ int c_stride = grad.stride(1);
+ int h_stride = grad.stride(2);
+ int w_stride = grad.stride(3);
+ // grad is passed as-is (indexed via the strides above); only rois is made contiguous
+ auto rois_ = rois.contiguous();
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ grad.scalar_type(), "ROIAlignRotated_backward", [&] {
+ ROIAlignRotatedBackward(
+ grad.numel(),
+ grad.data_ptr(),
+ spatial_scale,
+ channels,
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ sampling_ratio,
+ grad_input.data_ptr(),
+ rois_.data_ptr(),
+ n_stride,
+ c_stride,
+ h_stride,
+ w_stride);
+ });
+ return grad_input;
+}
+
+} // namespace detectron2
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..fca186519143b168a912c880a4cf495a0a5a9322
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu
@@ -0,0 +1,443 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#include
+#include
+#include
+#include
+
+// Loop header for kernels: i starts at the thread's global index and advances by blockDim.x * gridDim.x while i < n. TODO make it in a common file
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
+ i += blockDim.x * gridDim.x)
+
+// Note: this implementation originates from the Caffe2 ROIAlignRotated Op
+// and PyTorch ROIAlign (non-rotated) Op implementations.
+// The key difference between this implementation and those ones is
+// we don't do "legacy offset" in this version, as there aren't many previous
+// works, if any, using the "legacy" ROIAlignRotated Op.
+// This would make the interface a bit cleaner.
+
+namespace detectron2 {
+
+namespace {
+
+template
+__device__ T bilinear_interpolate(  // returns the bilinearly interpolated value of `input` (one H*W plane) at (y, x)
+ const T* input,
+ const int height,
+ const int width,
+ T y,
+ T x) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty: out-of-range samples contribute 0
+ return 0;
+ }
+ // coordinates in [-1, 0) are clamped to 0
+ if (y < 0) {
+ y = 0;
+ }
+
+ if (x < 0) {
+ x = 0;
+ }
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+ // at or past the last row/column: use the last one for both neighbours and snap the coordinate to it
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;  // fractional offsets from the low neighbour
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+ // do bilinear interpolation
+ T v1 = input[y_low * width + x_low];
+ T v2 = input[y_low * width + x_high];
+ T v3 = input[y_high * width + x_low];
+ T v4 = input[y_high * width + x_high];
+ T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ return val;
+}
+
+template
+__device__ void bilinear_interpolate_gradient(  // computes the four bilinear weights and neighbour indices for the point (y, x)
+ const int height,
+ const int width,
+ T y,
+ T x,
+ T& w1,  // outputs: weights for (y_low,x_low), (y_low,x_high), (y_high,x_low), (y_high,x_high)
+ T& w2,
+ T& w3,
+ T& w4,
+ int& x_low,  // outputs: neighbour indices, all set to -1 when the point is out of range
+ int& x_high,
+ int& y_low,
+ int& y_high) {
+ // deal with cases that inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty: zero weights and -1 indices
+ w1 = w2 = w3 = w4 = 0.;
+ x_low = x_high = y_low = y_high = -1;
+ return;
+ }
+ // coordinates in [-1, 0) are clamped to 0
+ if (y < 0) {
+ y = 0;
+ }
+
+ if (x < 0) {
+ x = 0;
+ }
+
+ y_low = (int)y;
+ x_low = (int)x;
+ // at or past the last row/column: use the last one for both neighbours and snap the coordinate to it
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (T)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (T)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ T ly = y - y_low;  // fractional offsets from the low neighbour
+ T lx = x - x_low;
+ T hy = 1. - ly, hx = 1. - lx;
+
+ // reference in forward
+ // T v1 = input[y_low * width + x_low];
+ // T v2 = input[y_low * width + x_high];
+ // T v3 = input[y_high * width + x_low];
+ // T v4 = input[y_high * width + x_high];
+ // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+ w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+ return;
+}
+
+} // namespace
+
+template
+__global__ void RoIAlignRotatedForward(  // forward kernel: each loop iteration computes one pooled output element
+ const int nthreads,  // number of output elements to compute
+ const T* input,
+ const T spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,  // samples per bin along each axis when > 0; otherwise derived from the ROI size
+ const T* rois,  // read as 6 consecutive values per ROI
+ T* top_data) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ // ROI layout as read below: [batch index, center w, center h, width, height, angle in degrees]
+ const T* current_roi = rois + n * 6;
+ int roi_batch_ind = current_roi[0];
+
+ // Do not use rounding; this implementation detail is critical
+ // ROIAlignRotated supports align == true, i.e., continuous coordinate
+ // by default, thus the 0.5 offset
+ T offset = (T)0.5;
+ T roi_center_w = current_roi[1] * spatial_scale - offset;
+ T roi_center_h = current_roi[2] * spatial_scale - offset;
+ T roi_width = current_roi[3] * spatial_scale;
+ T roi_height = current_roi[4] * spatial_scale;
+ T theta = current_roi[5] * M_PI / 180.0;  // degrees to radians
+ T cos_theta = cos(theta);
+ T sin_theta = sin(theta);
+
+ T bin_size_h = static_cast(roi_height) / static_cast(pooled_height);
+ T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);
+
+ const T* offset_input =
+ input + (roi_batch_ind * channels + c) * height * width;
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ T roi_start_h = -roi_height / 2.0;
+ T roi_start_w = -roi_width / 2.0;
+
+ // We do average (integral) pooling inside a bin; max(..., 1) avoids a zero divisor
+ const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
+
+ T output_val = 0.;
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+ {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+ T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+ T val = bilinear_interpolate(offset_input, height, width, y, x);
+ output_val += val;
+ }
+ }
+ output_val /= count;
+
+ top_data[index] = output_val;
+ }
+}
+
+template
+__global__ void RoIAlignRotatedBackwardFeature(  // backward kernel: scatters each output gradient onto the input cells its samples read
+ const int nthreads,  // number of output-gradient elements to process
+ const T* top_diff,  // indexed below as a dense (n, c, ph, pw) array
+ const int num_rois,  // not read in the kernel body
+ const T spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ T* bottom_diff,  // accumulated into with atomicAdd
+ const T* rois) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // (n, c, ph, pw) is an element in the pooled output
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int c = (index / pooled_width / pooled_height) % channels;
+ int n = index / pooled_width / pooled_height / channels;
+ // ROI layout as read below: [batch index, center w, center h, width, height, angle in degrees]
+ const T* current_roi = rois + n * 6;
+ int roi_batch_ind = current_roi[0];
+
+ // Do not use rounding; this implementation detail is critical
+ // ROIAlignRotated supports align == true, i.e., continuous coordinate
+ // by default, thus the 0.5 offset
+ T offset = (T)0.5;
+ T roi_center_w = current_roi[1] * spatial_scale - offset;
+ T roi_center_h = current_roi[2] * spatial_scale - offset;
+ T roi_width = current_roi[3] * spatial_scale;
+ T roi_height = current_roi[4] * spatial_scale;
+ T theta = current_roi[5] * M_PI / 180.0;  // degrees to radians
+ T cos_theta = cos(theta);
+ T sin_theta = sin(theta);
+
+ T bin_size_h = static_cast(roi_height) / static_cast(pooled_height);
+ T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);
+
+ T* offset_bottom_diff =
+ bottom_diff + (roi_batch_ind * channels + c) * height * width;
+
+ int top_offset = (n * channels + c) * pooled_height * pooled_width;
+ const T* offset_top_diff = top_diff + top_offset;
+ const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y).
+ // Appropriate translation needs to be applied after.
+ T roi_start_h = -roi_height / 2.0;
+ T roi_start_w = -roi_width / 2.0;
+
+ // We do average (integral) pooling inside a bin
+ const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4; only used inside the loops below, which do not run when it is 0
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
+ {
+ const T yy = roi_start_h + ph * bin_size_h +
+ static_cast(iy + .5f) * bin_size_h /
+ static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ const T xx = roi_start_w + pw * bin_size_w +
+ static_cast(ix + .5f) * bin_size_w /
+ static_cast(roi_bin_grid_w);
+
+ // Rotate by theta around the center and translate
+ T y = yy * cos_theta - xx * sin_theta + roi_center_h;
+ T x = yy * sin_theta + xx * cos_theta + roi_center_w;
+
+ T w1, w2, w3, w4;
+ int x_low, x_high, y_low, y_high;
+
+ bilinear_interpolate_gradient(
+ height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high);
+
+ T g1 = top_diff_this_bin * w1 / count;
+ T g2 = top_diff_this_bin * w2 / count;
+ T g3 = top_diff_this_bin * w3 / count;
+ T g4 = top_diff_this_bin * w4 / count;
+ // indices are -1 when the sample fell outside the feature map; such samples are skipped
+ if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+ atomicAdd(
+ offset_bottom_diff + y_low * width + x_low, static_cast(g1));
+ atomicAdd(
+ offset_bottom_diff + y_low * width + x_high, static_cast(g2));
+ atomicAdd(
+ offset_bottom_diff + y_high * width + x_low, static_cast(g3));
+ atomicAdd(
+ offset_bottom_diff + y_high * width + x_high, static_cast(g4));
+ } // if
+ } // ix
+ } // iy
+ } // CUDA_1D_KERNEL_LOOP
+} // RoIAlignRotatedBackwardFeature
+
+at::Tensor ROIAlignRotated_forward_cuda(  // host-side CUDA forward: validates arguments, allocates the output, launches the kernel
+ const at::Tensor& input,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio) {
+ AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
+ AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+ at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+ at::CheckedFrom c = "ROIAlignRotated_forward_cuda";
+ at::checkAllSameGPU(c, {input_t, rois_t});
+ at::checkAllSameType(c, {input_t, rois_t});
+ at::cuda::CUDAGuard device_guard(input.device());
+
+ auto num_rois = rois.size(0);
+ auto channels = input.size(1);
+ auto height = input.size(2);
+ auto width = input.size(3);
+ // at::empty (not zeros): the kernel writes top_data[index] for every index below output_size
+ auto output = at::empty(
+ {num_rois, channels, pooled_height, pooled_width}, input.options());
+ auto output_size = num_rois * pooled_height * pooled_width * channels;
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ // 512 threads per block; block count is ceil(output_size / 512) capped at 4096
+ dim3 grid(std::min(
+ at::cuda::ATenCeilDiv(
+ static_cast(output_size), static_cast(512)),
+ static_cast(4096)));
+ dim3 block(512);
+ // nothing to launch for an empty output
+ if (output.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return output;
+ }
+ // the kernel indexes raw pointers, so it is given contiguous tensors
+ auto input_ = input.contiguous(), rois_ = rois.contiguous();
+ AT_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "ROIAlignRotated_forward", [&] {
+ RoIAlignRotatedForward<<>>(
+ output_size,
+ input_.data_ptr(),
+ spatial_scale,
+ channels,
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ sampling_ratio,
+ rois_.data_ptr(),
+ output.data_ptr());
+ });
+ cudaDeviceSynchronize();  // NOTE(review): the backward launcher has no such device-wide sync - confirm it is needed here
+ AT_CUDA_CHECK(cudaGetLastError());
+ return output;
+}
+
+// Host-side CUDA backward. Takes the input's sizes (batch_size, channels, height, width) rather than the input tensor itself.
+at::Tensor ROIAlignRotated_backward_cuda(
+ const at::Tensor& grad,
+ const at::Tensor& rois,
+ const float spatial_scale,
+ const int pooled_height,
+ const int pooled_width,
+ const int batch_size,
+ const int channels,
+ const int height,
+ const int width,
+ const int sampling_ratio) {
+ AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
+ AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
+
+ at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
+ at::CheckedFrom c = "ROIAlignRotated_backward_cuda";
+ at::checkAllSameGPU(c, {grad_t, rois_t});
+ at::checkAllSameType(c, {grad_t, rois_t});
+ at::cuda::CUDAGuard device_guard(grad.device());
+
+ auto num_rois = rois.size(0);
+ auto grad_input =
+ at::zeros({batch_size, channels, height, width}, grad.options());  // zero-initialized: the kernel accumulates into it
+
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ // 512 threads per block; block count is ceil(grad.numel() / 512) capped at 4096
+ dim3 grid(std::min(
+ at::cuda::ATenCeilDiv(
+ static_cast(grad.numel()), static_cast(512)),
+ static_cast(4096)));
+ dim3 block(512);
+
+ // handle possibly empty gradients
+ if (grad.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return grad_input;
+ }
+ // the kernel indexes top_diff as a dense array, so grad is made contiguous first
+ auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
+ AT_DISPATCH_FLOATING_TYPES(
+ grad.scalar_type(), "ROIAlignRotated_backward", [&] {
+ RoIAlignRotatedBackwardFeature<<>>(
+ grad.numel(),
+ grad_.data_ptr(),
+ num_rois,
+ spatial_scale,
+ channels,
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ sampling_ratio,
+ grad_input.data_ptr(),
+ rois_.data_ptr());
+ });
+ AT_CUDA_CHECK(cudaGetLastError());
+ return grad_input;
+}
+
+} // namespace detectron2
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
new file mode 100644
index 0000000000000000000000000000000000000000..3bf383b8ed9b358b5313d433a9682c294dfb77e4
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
@@ -0,0 +1,35 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#pragma once
+#include
+
+namespace detectron2 {
+
+at::Tensor box_iou_rotated_cpu( // CPU implementation; only declared here, the body lives in box_iou_rotated_cpu.cpp
+ const at::Tensor& boxes1,
+ const at::Tensor& boxes2);
+
+#if defined(WITH_CUDA) || defined(WITH_HIP) // GPU variant is only declared when WITH_CUDA or WITH_HIP is defined
+at::Tensor box_iou_rotated_cuda( // GPU implementation with the same signature as the CPU one
+ const at::Tensor& boxes1,
+ const at::Tensor& boxes2);
+#endif // WITH_CUDA || WITH_HIP
+
+// Interface for Python: routes to the CUDA or CPU implementation based on the device of boxes1; both inputs are passed on as contiguous tensors.
+// The two tensors must agree on CUDA-ness. This is enforced with AT_ASSERTM because a plain assert() is compiled out when NDEBUG is defined.
+// inline is needed to prevent multiple function definitions when this header is included by different cpps
+inline at::Tensor box_iou_rotated(
+    const at::Tensor& boxes1,
+    const at::Tensor& boxes2) {
+  AT_ASSERTM(boxes1.device().is_cuda() == boxes2.device().is_cuda(), "boxes1 and boxes2 must both be CUDA tensors or both be non-CUDA tensors");
+  if (boxes1.device().is_cuda()) {
+#if defined(WITH_CUDA) || defined(WITH_HIP)
+    return box_iou_rotated_cuda(boxes1.contiguous(), boxes2.contiguous());
+#else
+    AT_ERROR("Detectron2 is not compiled with GPU support!");
+#endif
+  }
+
+  return box_iou_rotated_cpu(boxes1.contiguous(), boxes2.contiguous());
+}
+
+} // namespace detectron2
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c843487b5fa4e8077dd27402ec99009266ddda8d
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
@@ -0,0 +1,39 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#include "box_iou_rotated.h"
+#include "box_iou_rotated_utils.h"
+
+namespace detectron2 {
+
+template // NOTE(review): the template parameter list appears to be missing in this text
+void box_iou_rotated_cpu_kernel( // Stores the result of single_box_iou_rotated for every (boxes1[i], boxes2[j]) pair into ious.
+ const at::Tensor& boxes1,
+ const at::Tensor& boxes2,
+ at::Tensor& ious) { // output, indexed as a flat array of num_boxes1 * num_boxes2 entries
+ auto num_boxes1 = boxes1.size(0);
+ auto num_boxes2 = boxes2.size(0);
+
+ for (int i = 0; i < num_boxes1; i++) {
+ for (int j = 0; j < num_boxes2; j++) {
+ ious[i * num_boxes2 + j] = single_box_iou_rotated( // row-major layout: pair (i, j) lands at i * num_boxes2 + j
+ boxes1[i].data_ptr(), boxes2[j].data_ptr());
+ }
+ }
+}
+
+at::Tensor box_iou_rotated_cpu( // CPU entry point: returns the pairwise results reshaped to {num_boxes1, num_boxes2}
+ // input must be contiguous:
+ const at::Tensor& boxes1,
+ const at::Tensor& boxes2) {
+ auto num_boxes1 = boxes1.size(0);
+ auto num_boxes2 = boxes2.size(0);
+ at::Tensor ious = // flat output buffer from at::empty; dtype is set to float, remaining options come from boxes1
+ at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
+
+ box_iou_rotated_cpu_kernel(boxes1, boxes2, ious); // writes every element of ious
+
+ // reshape from 1d array to 2d array
+ auto shape = std::vector{num_boxes1, num_boxes2};
+ return ious.reshape(shape);
+}
+
+} // namespace detectron2
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
new file mode 100644
index 0000000000000000000000000000000000000000..952710e53041187907fbd113f8d0d0fa24134a86
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
@@ -0,0 +1,130 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#include
+#include
+#include
+#include
+#include "box_iou_rotated_utils.h"
+
+namespace detectron2 {
+
+// 2D block with 32 * 16 = 512 threads per block
+const int BLOCK_DIM_X = 32; // block size along x; also sizes the shared boxes1 tile (BLOCK_DIM_X * 5 floats)
+const int BLOCK_DIM_Y = 16; // block size along y; also sizes the shared boxes2 tile (BLOCK_DIM_Y * 5 floats)
+
+template // NOTE(review): the template parameter list appears to be missing in this text
+__global__ void box_iou_rotated_cuda_kernel( // Each block fills one blockDim.x-by-blockDim.y tile of the n_boxes1 x n_boxes2 result.
+ const int n_boxes1,
+ const int n_boxes2,
+ const T* dev_boxes1, // read as 5 consecutive values per box (index box * 5 + k)
+ const T* dev_boxes2, // same 5-values-per-box layout
+ T* dev_ious) { // row-major output: pair (i, j) is written at i * n_boxes2 + j
+ const int row_start = blockIdx.x * blockDim.x; // first boxes1 index covered by this block
+ const int col_start = blockIdx.y * blockDim.y; // first boxes2 index covered by this block
+
+ const int row_size = min(n_boxes1 - row_start, blockDim.x); // clipped when the block overhangs the end of boxes1
+ const int col_size = min(n_boxes2 - col_start, blockDim.y); // clipped when the block overhangs the end of boxes2
+
+ __shared__ float block_boxes1[BLOCK_DIM_X * 5]; // per-block copy of this tile's boxes1 entries, held as float
+ __shared__ float block_boxes2[BLOCK_DIM_Y * 5]; // per-block copy of this tile's boxes2 entries, held as float
+
+ // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
+ if (threadIdx.x < row_size && threadIdx.y == 0) { // only threads with threadIdx.y == 0 load; one box per thread
+ block_boxes1[threadIdx.x * 5 + 0] =
+ dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
+ block_boxes1[threadIdx.x * 5 + 1] =
+ dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
+ block_boxes1[threadIdx.x * 5 + 2] =
+ dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
+ block_boxes1[threadIdx.x * 5 + 3] =
+ dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
+ block_boxes1[threadIdx.x * 5 + 4] =
+ dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
+ }
+
+ if (threadIdx.x < col_size && threadIdx.y == 0) { // boxes2 tile is also loaded along threadIdx.x (col_size <= blockDim.y)
+ block_boxes2[threadIdx.x * 5 + 0] =
+ dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
+ block_boxes2[threadIdx.x * 5 + 1] =
+ dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
+ block_boxes2[threadIdx.x * 5 + 2] =
+ dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
+ block_boxes2[threadIdx.x * 5 + 3] =
+ dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
+ block_boxes2[threadIdx.x * 5 + 4] =
+ dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
+ }
+ __syncthreads(); // both shared tiles must be completely written before any thread reads them below
+
+ if (threadIdx.x < row_size && threadIdx.y < col_size) { // threads outside the clipped tile write nothing
+ int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y; // flat index of pair (row, col)
+ dev_ious[offset] = single_box_iou_rotated(
+ block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
+ }
+}
+
+at::Tensor box_iou_rotated_cuda( // GPU entry point: launches box_iou_rotated_cuda_kernel and returns the result as a 2D view
+ // input must be contiguous
+ const at::Tensor& boxes1,
+ const at::Tensor& boxes2) {
+ using scalar_t = float; // only float inputs are accepted, see the scalar_type checks below
+ AT_ASSERTM(
+ boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor");
+ AT_ASSERTM(
+ boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor");
+ AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
+ AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
+ at::cuda::CUDAGuard device_guard(boxes1.device()); // scoped guard built from boxes1's device
+
+ auto num_boxes1 = boxes1.size(0);
+ auto num_boxes2 = boxes2.size(0);
+
+ at::Tensor ious = // flat output buffer; stays untouched when either input has zero boxes (it then has zero elements)
+ at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
+
+ bool transpose = false; // set when the two inputs are swapped for the launch
+ if (num_boxes1 > 0 && num_boxes2 > 0) { // nothing to launch for an empty input
+ scalar_t *data1 = boxes1.data_ptr(),
+ *data2 = boxes2.data_ptr();
+
+ if (num_boxes2 > 65535 * BLOCK_DIM_Y) { // boxes2 would need more than 65535 blocks along grid y
+ AT_ASSERTM(
+ num_boxes1 <= 65535 * BLOCK_DIM_Y,
+ "Too many boxes for box_iou_rotated_cuda!");
+ // x dim is allowed to be large, but y dim cannot,
+ // so we transpose the two to avoid "invalid configuration argument"
+ // error. We assume one of them is small. Otherwise the result is hard to
+ // fit in memory anyway.
+ std::swap(num_boxes1, num_boxes2); // from here on the counts describe the swapped problem
+ std::swap(data1, data2);
+ transpose = true;
+ }
+
+ const int blocks_x = // ceil(num_boxes1 / BLOCK_DIM_X)
+ at::cuda::ATenCeilDiv(static_cast(num_boxes1), BLOCK_DIM_X);
+ const int blocks_y = // ceil(num_boxes2 / BLOCK_DIM_Y)
+ at::cuda::ATenCeilDiv(static_cast(num_boxes2), BLOCK_DIM_Y);
+
+ dim3 blocks(blocks_x, blocks_y);
+ dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ box_iou_rotated_cuda_kernel<<>>( // NOTE(review): launch configuration appears to be missing in this text
+ num_boxes1,
+ num_boxes2,
+ data1,
+ data2,
+ (scalar_t*)ious.data_ptr());
+
+ AT_CUDA_CHECK(cudaGetLastError()); // surfaces an error reported for the launch above
+ }
+
+ // reshape from 1d array to 2d array
+ auto shape = std::vector{num_boxes1, num_boxes2}; // uses the (possibly swapped) counts
+ if (transpose) {
+ return ious.view(shape).t(); // kernel wrote boxes2-major data; .t() restores boxes1 x boxes2 order
+ } else {
+ return ious.view(shape);
+ }
+}
+
+} // namespace detectron2
diff --git a/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..b54a5dde2ca11a74d29c4d8adb7fe1634f5baf9c
--- /dev/null
+++ b/sd-webui-controlnet/annotator/oneformer/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
@@ -0,0 +1,370 @@
+// Copyright (c) Facebook, Inc. and its affiliates.
+#pragma once
+
+#include
+#include
+
+#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1
+// Designates functions callable from the host (CPU) and the device (GPU); the #else branch defines the same two macros without GPU qualifiers.
+#define HOST_DEVICE __host__ __device__
+#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
+#else // none of the compiler checks above matched: HOST_DEVICE is empty and HOST_DEVICE_INLINE is plain inline
+#include
+#define HOST_DEVICE
+#define HOST_DEVICE_INLINE HOST_DEVICE inline
+#endif // __CUDACC__ / __HCC__ / __HIP__
+
+namespace detectron2 {
+
+namespace {
+
+template // NOTE(review): the template parameter list appears to be missing in this text
+struct RotatedBox { // Plain aggregate holding the five parameters of one rotated box
+ T x_ctr, y_ctr, w, h, a; // center (x_ctr, y_ctr), extents w and h, angle a (get_rotated_vertices converts it from degrees to radians)
+};
+
+template // NOTE(review): the template parameter list appears to be missing in this text
+struct Point { // 2D point/vector with component-wise arithmetic, usable wherever HOST_DEVICE_INLINE allows
+ T x, y;
+ HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {} // both coordinates default to 0
+ HOST_DEVICE_INLINE Point operator+(const Point& p) const { // component-wise sum, returns a new Point
+ return Point(x + p.x, y + p.y);
+ }
+ HOST_DEVICE_INLINE Point& operator+=(const Point& p) { // in-place component-wise sum, returns *this
+ x += p.x;
+ y += p.y;
+ return *this;
+ }
+ HOST_DEVICE_INLINE Point operator-(const Point& p) const { // component-wise difference, returns a new Point
+ return Point(x - p.x, y - p.y);
+ }
+ HOST_DEVICE_INLINE Point operator*(const T coeff) const { // scales both coordinates by coeff
+ return Point(x * coeff, y * coeff);
+ }
+};
+
+template // NOTE(review): the template parameter list appears to be missing in this text
+HOST_DEVICE_INLINE T dot_2d(const Point& A, const Point& B) { // 2D dot product of A and B
+ return A.x * B.x + A.y * B.y;
+}
+
+// Scalar 2D cross product A.x * B.y - B.x * A.y. R: result type. can be different from input type
+template // NOTE(review): the template parameter list appears to be missing in this text
+HOST_DEVICE_INLINE R cross_2d(const Point& A, const Point& B) {
+ return static_cast(A.x) * static_cast(B.y) - // every coordinate is cast before multiplying (cast target missing in this text; presumably R)
+ static_cast(B.x) * static_cast(A.y);
+}
+
+template // NOTE(review): the template parameter list appears to be missing in this text
+HOST_DEVICE_INLINE void get_rotated_vertices( // Writes the four corner points of `box` into pts
+ const RotatedBox& box,
+ Point (&pts)[4]) { // output: exactly four points, fully overwritten
+ // M_PI / 180. == 0.01745329251
+ double theta = box.a * 0.01745329251; // box.a converted from degrees to radians
+ T cosTheta2 = (T)cos(theta) * 0.5f; // half of cos(theta): multiplied by w or h below, i.e. half-extents
+ T sinTheta2 = (T)sin(theta) * 0.5f; // half of sin(theta)
+
+ // y: top --> down; x: left --> right
+ pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w; // pts[0] and pts[1] are computed directly from the center
+ pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
+ pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w;
+ pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
+ pts[2].x = 2 * box.x_ctr - pts[0].x; // pts[2] is pts[0] mirrored through the box center
+ pts[2].y = 2 * box.y_ctr - pts[0].y;
+ pts[3].x = 2 * box.x_ctr - pts[1].x; // pts[3] is pts[1] mirrored through the box center
+ pts[3].y = 2 * box.y_ctr - pts[1].y;
+}
+
+template
+HOST_DEVICE_INLINE int get_intersection_points(
+ const Point (&pts1)[4],
+ const Point (&pts2)[4],
+ Point