|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import csv |
|
import onnxruntime as ort |
|
|
|
from PIL import Image |
|
from onnxruntime import InferenceSession |
|
from modules.config import path_clip_vision |
|
from modules.model_loader import load_file_from_url |
|
|
|
|
|
# Process-wide caches filled lazily by default_interrogator(), so the ONNX
# session and the tag table are built at most once and then reused.
global_model = None  # cached onnxruntime InferenceSession (None until first use)
global_csv = None    # cached tag-table rows, header row excluded (None until first use)
|
|
|
|
|
def _get_session(model_name):
    """Return the cached ONNX session for *model_name*, creating it on first use.

    The model file is only resolved through load_file_from_url on a cache miss.
    NOTE(review): the cache is not keyed by model_name; that is fine while a
    single hard-coded model is used.
    """
    global global_model
    if global_model is None:
        model_onnx_filename = load_file_from_url(
            url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.onnx',
            model_dir=path_clip_vision,
            file_name=f'{model_name}.onnx',
        )
        global_model = InferenceSession(model_onnx_filename, providers=ort.get_available_providers())
    return global_model


def _get_tag_rows(model_name):
    """Return the cached tag-table rows for *model_name* (header skipped), reading them on first use."""
    global global_csv
    if global_csv is None:
        model_csv_filename = load_file_from_url(
            url=f'https://huggingface.co/lllyasviel/misc/resolve/main/{model_name}.csv',
            model_dir=path_clip_vision,
            file_name=f'{model_name}.csv',
        )
        # Explicit encoding: tag names must not depend on the platform's locale
        # codec. newline='' is what the csv module expects from its file objects.
        with open(model_csv_filename, newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader)  # skip header row
            # Assign only after the whole file was read, so a failure mid-read
            # leaves the cache empty instead of half-filled.
            global_csv = list(reader)
    return global_csv


def _letterbox(image_rgb, size):
    """Fit *image_rgb* onto a white size x size canvas; return a (1, size, size, 3) float32 BGR batch."""
    image = Image.fromarray(image_rgb)
    # Scale so the longer side equals `size` (this up-scales small images too).
    ratio = float(size) / max(image.size)
    # Clamp to 1px so extremely elongated images never request a 0-sized resize.
    new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
    image = image.resize(new_size, Image.LANCZOS)
    square = Image.new("RGB", (size, size), (255, 255, 255))
    square.paste(image, ((size - new_size[0]) // 2, (size - new_size[1]) // 2))
    batch = np.array(square).astype(np.float32)
    batch = batch[:, :, ::-1]  # reverse channel order: RGB -> BGR
    return np.expand_dims(batch, 0)  # add leading batch axis


def _normalize_tag(tag):
    """Canonical form used to compare tags against the exclusion list."""
    return tag.strip().lower().replace("_", " ")


def _format_tag(tag):
    """Render a raw table tag for output: escape parentheses, show underscores as spaces."""
    # Parentheses are backslash-escaped, presumably so downstream prompt parsing
    # does not treat them as emphasis syntax.
    return tag.replace("(", "\\(").replace(")", "\\)").replace("_", " ")


def default_interrogator(image_rgb, threshold=0.35, character_threshold=0.85, exclude_tags=""):
    """Tag an image with the WD14 MOAT tagger and return a comma-separated tag string.

    Args:
        image_rgb: Array accepted by ``PIL.Image.fromarray``; presumably an
            HxWx3 uint8 RGB image, as the name suggests (confirm with callers).
        threshold: A general tag is kept when its score is strictly above this.
        character_threshold: A character tag is kept when its score is strictly
            above this.
        exclude_tags: Comma-separated tags to drop, compared case-insensitively.
            Both the underscore form (``long_hair``) and the space form shown in
            the output (``long hair``) match. ``None`` is treated as empty.

    Returns:
        Character tags first, then general tags, each in tag-table order,
        joined with ", ". Underscores appear as spaces and parentheses are
        backslash-escaped.
    """
    model_name = "wd-v1-4-moat-tagger-v2"
    model = _get_session(model_name)
    rows = _get_tag_rows(model_name)

    model_input = model.get_inputs()[0]
    # The batch built by _letterbox is NHWC, so shape[1] is the canvas side
    # length. The input is assumed square.
    size = model_input.shape[1]
    batch = _letterbox(image_rgb, size)

    output_name = model.get_outputs()[0].name
    scores = model.run([output_name], {model_input.name: batch})[0][0]  # first (only) batch item

    # Column 1 is the tag name and column 2 its category. The first "0" row
    # starts the general tags and the first "4" row starts the character tags.
    # NOTE(review): assumes general rows precede character rows in the table.
    general_start = None
    character_start = None
    for line_num, row in enumerate(rows):
        if general_start is None and row[2] == "0":
            general_start = line_num
        elif character_start is None and row[2] == "4":
            character_start = line_num
    scored = list(zip((row[1] for row in rows), scores))

    # A missing category must yield no tags. Slicing from None would instead
    # sweep in every other category.
    general = []
    if general_start is not None:
        general = [item for item in scored[general_start:character_start] if item[1] > threshold]
    character = []
    if character_start is not None:
        character = [item for item in scored[character_start:] if item[1] > character_threshold]

    excluded = {_normalize_tag(t) for t in (exclude_tags or "").split(",") if t.strip()}
    selected = [item for item in character + general if _normalize_tag(item[0]) not in excluded]

    return ", ".join(_format_tag(tag) for tag, _ in selected)
|
|