|
import gc |
|
import os |
|
import os.path |
|
import re |
|
import json |
|
import shutil |
|
from tqdm import tqdm |
|
import torch |
|
from statistics import mean |
|
import csv |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from importlib import reload |
|
from pprint import pprint |
|
import gradio as gr |
|
from modules import (script_callbacks, sd_models,sd_vae, shared) |
|
from modules.scripts import basedir |
|
from modules.sd_models import checkpoints_loaded, load_model,unload_model_weights |
|
from modules.shared import opts |
|
from modules.sd_samplers import samplers |
|
from modules.ui import create_output_panel, create_refresh_button |
|
import scripts.mergers.mergers |
|
import scripts.mergers.pluslora |
|
import scripts.mergers.xyplot |
|
import scripts.mergers.components as components |
|
from importlib import reload |
|
reload(scripts.mergers.mergers) |
|
reload(scripts.mergers.xyplot) |
|
reload(scripts.mergers.pluslora) |
|
import csv |
|
import scripts.mergers.pluslora as pluslora |
|
from scripts.mergers.mergers import (TYPESEG,EXCLUDE_CHOICES, freezemtime, rwmergelog, blockfromkey, clearcache, getcachelist) |
|
from scripts.mergers.xyplot import freezetime, nulister |
|
from scripts.mergers.model_util import filenamecutter, savemodel |
|
|
|
# Base directory of this script package, as reported by modules.scripts.basedir().
path_root = basedir()

# Path of the XYZ presets JSON file (also opened by the "Open XYZ Preset file" button).
xyzpath = os.path.join(path_root, "xyzpresets.json")

# Calculation modes offered by the "Calculation Mode" radio and the XYZ "calcmode" picker.
CALCMODES = [
    "normal",
    "cosineA",
    "cosineB",
    "trainDifference",
    "smoothAdd",
    "smoothAdd MT",
    "extract",
    "tensor",
    "tensor2",
    "self",
    "plus random",
]
|
|
|
class ResizeHandleRow(gr.Row):
    """A ``gr.Row`` that also fits inside gradio forms.

    Behaves like ``gr.Row`` but appends the ``resize-handle-row`` CSS class
    to its ``elem_classes`` and reports ``"row"`` as its block name.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Tag the row with an extra CSS class (presumably targeted by the
        # WebUI stylesheet/JS — confirm).
        css_classes = self.elem_classes
        css_classes.append("resize-handle-row")

    def get_block_name(self):
        # Report the plain gradio row block name.
        block_name = "row"
        return block_name
|
|
|
from typing import Union |
|
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """Clear the network-patching cache attributes stored on a layer module.

    Installed as ``networks.network_reset_cached_weight`` by
    ``fix_network_reset_cached_weight``. Resets the tuple of applied network
    names and drops the saved weight and bias backups.

    Args:
        self: the Conv2d/Linear module whose cache attributes are reset.
    """
    self.network_current_names = ()
    # Drop both backups (assigned left to right: weights, then bias).
    self.network_weights_backup, self.network_bias_backup = None, None
|
|
|
|
|
def fix_network_reset_cached_weight():
    """Replace ``networks.network_reset_cached_weight`` with this module's version.

    Best-effort: if the ``networks`` module cannot be imported, the patch is
    skipped silently. Only ordinary exceptions are swallowed;
    ``KeyboardInterrupt``/``SystemExit`` still propagate (the previous bare
    ``except:`` hid them too).
    """
    try:
        import networks as net
    except Exception:
        # ``networks`` is not available (or failed to import); nothing to patch.
        return
    net.network_reset_cached_weight = network_reset_cached_weight
|
|
|
def on_ui_tabs(): |
|
fix_network_reset_cached_weight() |
|
|
|
weights_presets="" |
|
userfilepath = os.path.join(path_root, "scripts","mbwpresets.txt") |
|
|
|
if os.path.isfile(userfilepath): |
|
try: |
|
with open(userfilepath) as f: |
|
weights_presets = f.read() |
|
filepath = userfilepath |
|
except OSError as e: |
|
pass |
|
else: |
|
filepath = os.path.join(path_root, "scripts","mbwpresets_master.txt") |
|
try: |
|
with open(filepath) as f: |
|
weights_presets = f.read() |
|
shutil.copyfile(filepath, userfilepath) |
|
except OSError as e: |
|
pass |
|
|
|
if "ALLR" not in weights_presets: weights_presets += ADDRAND |
|
|
|
with gr.Blocks() as supermergerui: |
|
with gr.Tab("Merge"): |
|
with ResizeHandleRow(equal_height=False): |
|
with gr.Column(variant="compact"): |
|
gr.HTML(value="<p>Merge models and load it for generation</p>") |
|
|
|
with gr.Row(): |
|
s_reverse= gr.Button(value="Load settings from:",elem_classes=["compact_button"],variant='primary') |
|
mergeid = gr.Textbox(label="merged model ID (-1 for last)", elem_id="model_converter_custom_name",value = "-1") |
|
mclearcache= gr.Button(value="Clear Cache",elem_classes=["compact_button"],variant='primary') |
|
|
|
with gr.Row(variant="compact"): |
|
model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model A",interactive=True) |
|
create_refresh_button(model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
|
|
model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model B",interactive=True) |
|
create_refresh_button(model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
|
|
model_c = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Model C",interactive=True) |
|
create_refresh_button(model_c, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
|
|
mode = gr.Radio(label = "Merge Mode",choices = ["Weight sum", "Add difference", "Triple sum", "sum Twice"], value="Weight sum", info="A*(1-alpha)+B*alpha") |
|
calcmode = gr.Radio(label = "Calculation Mode",choices = CALCMODES, value = "normal") |
|
with gr.Row(variant="compact"): |
|
with gr.Column(scale = 1): |
|
useblocks = gr.Checkbox(label="use MBW", info="use Merge Block Weights") |
|
with gr.Column(scale = 3), gr.Group() as alpha_group: |
|
with gr.Row(): |
|
base_alpha = gr.Slider(label="alpha", minimum=-1.0, maximum=2, step=0.001, value=0.5) |
|
base_beta = gr.Slider(label="beta", minimum=-1.0, maximum=2, step=0.001, value=0.25, interactive=False) |
|
|
|
|
|
with gr.Accordion("Options", open=False): |
|
with gr.Row(variant="compact"): |
|
save_sets = gr.CheckboxGroup(["use cuda","save model", "overwrite","safetensors","fp16","save metadata","copy config","prune","Reset CLIP ids","use old calc method","debug"], value=["safetensors"], show_label=False, label="save settings") |
|
with gr.Row(): |
|
components.id_sets = gr.CheckboxGroup(["image", "PNG info"], label="save merged model ID to") |
|
opt_value = gr.Slider(label="option(gamma) ", minimum=-1.0, maximum=20, step=0.1, value=0.3, interactive=True) |
|
with gr.Row(variant="compact"): |
|
with gr.Column(min_width = 50): |
|
with gr.Row(): |
|
custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="model_converter_custom_name") |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") |
|
create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") |
|
|
|
with gr.Row(variant="compact"): |
|
savecurrent = gr.Button(elem_id="savecurrent", elem_classes=["compact_button"], value="Save current merge(fp16 only)") |
|
|
|
with gr.Row(): |
|
components.merge = gr.Button(elem_id="model_merger_merge", elem_classes=["compact_button"], value="Merge!",variant='primary') |
|
components.mergeandgen = gr.Button(elem_id="model_merger_merge", elem_classes=["compact_button"], value="Merge&Gen",variant='primary') |
|
components.gen = gr.Button(elem_id="model_merger_merge", elem_classes=["compact_button"], value="Gen",variant='primary') |
|
stopmerge = gr.Button(elem_id="stopmerge", elem_classes=["compact_button"], value="Stop") |
|
|
|
|
|
with gr.Accordion("Merging Block Weights", open=False): |
|
with gr.Row(): |
|
isxl = gr.Radio(label = "Block Type",choices = ["1.X or 2.X", "XL"], value = "1.X or 2.X", type="index") |
|
|
|
with gr.Tab("Weights Setting"): |
|
with gr.Group(), gr.Tabs(): |
|
with gr.Tab("Weights for alpha"): |
|
with gr.Row(variant="compact"): |
|
weights_a = gr.Textbox(label="BASE,IN00,IN02,...IN11,M00,OUT00,...,OUT11",value = "0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5", show_copy_button=True) |
|
with gr.Row(scale=2): |
|
setalpha = gr.Button(elem_id="copytogen", value="↑ Set alpha",variant='primary', scale=3) |
|
readalpha = gr.Button(elem_id="copytogen", value="↓ Read alpha",variant='primary', scale=3) |
|
setx = gr.Button(elem_id="copytogen", value="↑ Set X", min_width="80px", scale=1) |
|
with gr.Tab("beta"): |
|
with gr.Row(variant="compact"): |
|
weights_b = gr.Textbox(label="BASE,IN00,IN02,...IN11,M00,OUT00,...,OUT11",value = "0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2", show_copy_button=True) |
|
with gr.Row(scale=2): |
|
setbeta = gr.Button(elem_id="copytogen", value="↑ Set beta",variant='primary', scale=3) |
|
readbeta = gr.Button(elem_id="copytogen", value="↓ Read beta",variant='primary', scale=3) |
|
sety = gr.Button(elem_id="copytogen", value="↑ Set Y", min_width="80px", scale=1) |
|
|
|
with gr.Group(), gr.Tabs(): |
|
with gr.Tab("Preset"): |
|
with gr.Row(): |
|
dd_preset_weight = gr.Dropdown(label="Select preset", choices=preset_name_list(weights_presets), interactive=True, elem_id="refresh_presets") |
|
preset_refresh = gr.Button(value='\U0001f504', elem_classes=["tool"]) |
|
|
|
with gr.Tab("Random Preset"): |
|
with gr.Row(): |
|
dd_preset_weight_r = gr.Dropdown(label="Load Romdom preset", choices=preset_name_list(weights_presets,True), interactive=True, elem_id="refresh_presets") |
|
preset_refresh_r = gr.Button(value='\U0001f504', elem_classes=["tool"]) |
|
luckab = gr.Radio(label = "for",choices = ["none", "alpha", "beta"], value = "none", type="value") |
|
|
|
with gr.Tab("Helper"): |
|
with gr.Column(): |
|
resetval = gr.Slider(label="Value", show_label=False, info="Value to set/add/mul", minimum=0, maximum=2, step=0.0001, value=0) |
|
resetopt = gr.Radio(label="Pre defined", show_label=False, choices = ["0", "0.25", "0.5", "0.75", "1"], value = "0", type="value") |
|
with gr.Column(): |
|
resetblockopt = gr.CheckboxGroup(["BASE","INP*","MID","OUT*"], value=["INP*","OUT*"], label="Blocks", show_label=False, info="Select blocks to change") |
|
with gr.Column(): |
|
with gr.Row(): |
|
resetweight = gr.Button(elem_classes=["reset"], value="Set") |
|
addweight = gr.Button(elem_classes=["reset"], value="Add") |
|
mulweight = gr.Button(elem_classes=["reset"], value="Mul") |
|
with gr.Row(): |
|
lower = gr.Slider(label="Slider Lower Limit", minimum=-2, maximum=3, step=0.1, value=0) |
|
upper = gr.Slider(label="Slider Upper Limit", minimum=-2, maximum=3, step=0.1, value=1) |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1, min_width=100): |
|
gr.Slider(visible=False) |
|
with gr.Column(scale=2, min_width=200): |
|
base = gr.Slider(label="Base", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
with gr.Column(scale=1, min_width=100): |
|
gr.Slider(visible=False) |
|
with gr.Row(): |
|
with gr.Column(scale=2, min_width=200): |
|
in00 = gr.Slider(label="IN00", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in01 = gr.Slider(label="IN01", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in02 = gr.Slider(label="IN02", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in03 = gr.Slider(label="IN03", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in04 = gr.Slider(label="IN04", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in05 = gr.Slider(label="IN05", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in06 = gr.Slider(label="IN06", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in07 = gr.Slider(label="IN07", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in08 = gr.Slider(label="IN08", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in09 = gr.Slider(label="IN09", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in10 = gr.Slider(label="IN10", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
in11 = gr.Slider(label="IN11", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
with gr.Column(scale=2, min_width=200): |
|
ou11 = gr.Slider(label="OUT11", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou10 = gr.Slider(label="OUT10", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou09 = gr.Slider(label="OUT09", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou08 = gr.Slider(label="OUT08", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou07 = gr.Slider(label="OUT07", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou06 = gr.Slider(label="OUT06", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou05 = gr.Slider(label="OUT05", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou04 = gr.Slider(label="OUT04", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou03 = gr.Slider(label="OUT03", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou02 = gr.Slider(label="OUT02", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou01 = gr.Slider(label="OUT01", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
ou00 = gr.Slider(label="OUT00", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
with gr.Row(): |
|
with gr.Column(scale=1, min_width=100): |
|
gr.Slider(visible=False) |
|
with gr.Column(scale=2, min_width=200): |
|
mi00 = gr.Slider(label="M00", minimum=0, maximum=1, step=0.0001, value=0.5) |
|
with gr.Column(scale=1, min_width=100): |
|
gr.Slider(visible=False) |
|
|
|
with gr.Tab("Weights Presets"): |
|
with gr.Row(): |
|
s_reloadtext = gr.Button(value="Reload Presets",variant='primary') |
|
s_reloadtags = gr.Button(value="Reload Tags",variant='primary') |
|
s_savetext = gr.Button(value="Save Presets",variant='primary') |
|
s_openeditor = gr.Button(value="Open TextEditor",variant='primary') |
|
weightstags= gr.Textbox(label="available",lines = 2,value=tagdicter(weights_presets),visible =True,interactive =True) |
|
wpresets= gr.TextArea(label="",value=(weights_presets+ADDRAND),visible =True,interactive = True) |
|
|
|
with gr.Accordion("XYZ Plot", open=False): |
|
with gr.Row(): |
|
x_type = gr.Dropdown(label="X type", choices=[x for x in TYPESEG], value="alpha", type="index") |
|
x_randseednum = gr.Number(value=3, label="number of -1", interactive=True, visible = True) |
|
xgrid = gr.Textbox(label="X Values",lines=3,value="0.25,0.5,0.75") |
|
y_type = gr.Dropdown(label="Y type", choices=[y for y in TYPESEG], value="none", type="index") |
|
ygrid = gr.Textbox(label="Y Values (Disabled if blank)",lines=3,value="",visible =False) |
|
z_type = gr.Dropdown(label="Z type", choices=[y for y in TYPESEG], value="none", type="index") |
|
zgrid = gr.Textbox(label="Z Values (Disabled if blank)",lines=3,value="",visible =False) |
|
esettings = gr.CheckboxGroup(label = "XYZ plot settings",choices=["swap XY","save model","save csv","save anime gif","not save grid","print change","0 stock"],type="value",interactive=True) |
|
|
|
with gr.Row(): |
|
components.gengrid = gr.Button(elem_id="model_merger_merge", value="Run XYZ Plot",variant='primary') |
|
stopgrid = gr.Button(elem_id="model_merger_merge", value="Stop XYZ Plot") |
|
components.s_reserve1 = gr.Button(value="Reserve XYZ Plot",variant='primary') |
|
|
|
with gr.Accordion("XYZ presets",open = True): |
|
with gr.Row(): |
|
xyzpresets = gr.Dropdown(label="Preset name",allow_custom_value=True,choices=get_xyzpreset_keylist(),scale=10) |
|
refreshxyzpresets_b = gr.Button(value='\U0001f504', elem_classes=["tool"],scale=1) |
|
savexyzpreset_overwrite = gr.CheckboxGroup(label = " ",choices=["Overwrite"],type="index",interactive=True,scale=1) |
|
with gr.Row(): |
|
loadxyzpreset_b = gr.Button(value="Load preset",variant='primary') |
|
savexyzpreset_b = gr.Button(value="Save current plot as preset",variant='primary') |
|
deletexyzpreset_b = gr.Button(value="Delete preset",variant='primary') |
|
openxyzpreset = gr.Button(value="Open XYZ Preset file") |
|
|
|
openxyzpreset.click(fn=lambda:subprocess.Popen(['start', xyzpath], shell=True)) |
|
|
|
with gr.Column(visible = False, variant="compact") as row_inputers: |
|
with gr.Row(variant="compact"): |
|
inputer = gr.Textbox(label="Selected", lines=1, value="", show_copy_button=True) |
|
with gr.Row(variant="compact"): |
|
addtox = gr.Button(value="↑ Add to X Values") |
|
addtoy = gr.Button(value="↑ Add to Y Values") |
|
addtoz = gr.Button(value="↑ Add to Z Values") |
|
with gr.Row(visible = False) as row_blockids: |
|
blockids = gr.CheckboxGroup(label = "block IDs",choices=BLOCKID[:-1],type="value",interactive=True) |
|
with gr.Row(visible = False) as row_calcmode: |
|
calcmodes = gr.CheckboxGroup(label = "calcmode",choices=CALCMODES,type="value",interactive=True) |
|
with gr.Row(visible = False) as row_checkpoints: |
|
checkpoints = gr.CheckboxGroup(label = "checkpoints",choices=[x.model_name for x in sd_models.checkpoints_list.values()],type="value",interactive=True) |
|
create_refresh_button(checkpoints, sd_models.list_models, lambda: {"choices": [x.model_name for x in sd_models.checkpoints_list.values()]}, "refresh_checkpoint_xyz") |
|
with gr.Row(visible = False) as row_blocks: |
|
gr.HTML(value="<p>BASE,IN00,IN01,IN02,IN03,IN04,IN05,IN06,IN07,IN08,IN09,IN10,IN11<br>,M00,OUT00,OUT01,OUT02,OUT03,OUT04,OUT05,OUT06,OUT07,OUT08,OUT09,OUT10,OUT11,Adjust,VAE,print</p>") |
|
|
|
with gr.Accordion("Reservation", open=False): |
|
with gr.Row(): |
|
components.s_reserve = gr.Button(value="Reserve XY Plot",variant='primary') |
|
s_reloadreserve = gr.Button(value="Reloat List",variant='primary') |
|
components.s_startreserve = gr.Button(value="Start XY plot",variant='primary') |
|
s_delreserve = gr.Button(value="Delete list(-1 for all)",variant='primary') |
|
s_delnum = gr.Number(value=1, label="Delete num : ", interactive=True, visible = True,precision =0) |
|
with gr.Row(): |
|
components.numaframe = gr.Dataframe( |
|
headers=["No.","status","xtype","xmenber","ytype","ymenber","ztype","zmenber","model A","model B","model C","alpha","beta","mode","use MBW","weights alpha","weights beta"], |
|
row_count=5,) |
|
|
|
components.dtrue = gr.Checkbox(value = True, visible = False) |
|
components.dfalse = gr.Checkbox(value = False,visible = False) |
|
dummy_t = gr.Textbox(value = "",visible = False) |
|
|
|
with gr.Accordion("Elemental Merge",open = False): |
|
with gr.Row(): |
|
components.esettings1 = gr.CheckboxGroup(label = "settings",choices=["print change"],type="value",interactive=True) |
|
with gr.Row(): |
|
deep = gr.Textbox(label="Blocks:Element:Ratio,Blocks:Element:Ratio,...",lines=2,value="") |
|
|
|
with gr.Accordion("Adjust", open=False) as acc_ad: |
|
with gr.Row(variant="compact"): |
|
finetune = gr.Textbox(label="Adjust", show_label=False, info="Adjust IN,OUT,OUT2,Contrast,Brightness,COL1,COL2,COL3", visible=True, value="", lines=1) |
|
finetune_write = gr.Button(value="↑", elem_classes=["tool"]) |
|
finetune_read = gr.Button(value="↓", elem_classes=["tool"]) |
|
finetune_reset = gr.Button(value="\U0001f5d1\ufe0f", elem_classes=["tool"]) |
|
with gr.Row(variant="compact"): |
|
with gr.Column(scale=1, min_width=100): |
|
detail1 = gr.Slider(label="IN", minimum=-6, maximum=6, step=0.01, value=0, info="Detail/Noise") |
|
with gr.Column(scale=1, min_width=100): |
|
detail2 = gr.Slider(label="OUT", minimum=-6, maximum=6, step=0.01, value=0, info="Detail/Noise") |
|
with gr.Column(scale=1, min_width=100): |
|
detail3 = gr.Slider(label="OUT2", minimum=-6, maximum=6, step=0.01, value=0, info="Detail/Noise") |
|
with gr.Row(variant="compact"): |
|
with gr.Column(scale=1, min_width=100): |
|
contrast = gr.Slider(label="Contrast", minimum=-10, maximum=10, step=0.01, value=0, info="Contrast/Detail") |
|
with gr.Column(scale=1, min_width=100): |
|
bri = gr.Slider(label="Brightness", minimum=-10, maximum=10, step=0.01, value=0, info="Dark(Minius)-Bright(Plus)") |
|
with gr.Row(variant="compact"): |
|
with gr.Column(scale=1, min_width=100): |
|
col1 = gr.Slider(label="Cyan-Red", minimum=-10, maximum=10, step=0.01, value=0, info="Cyan(Minius)-Red(Plus)") |
|
with gr.Column(scale=1, min_width=100): |
|
col2 = gr.Slider(label="Magenta-Green", minimum=-10, maximum=10, step=0.01, value=0, info="Magenta(Minius)-Green(Plus)") |
|
with gr.Column(scale=1, min_width=100): |
|
col3 = gr.Slider(label="Yellow-Blue", minimum=-10, maximum=10, step=0.01, value=0, info="Yellow(Minius)-Blue(Plus)") |
|
|
|
finetune.change(fn=lambda x:gr.update(label = f"Adjust : {x}"if x != "" and x !="0,0,0,0,0,0,0,0" else "Adjust"),inputs=[finetune],outputs = [acc_ad]) |
|
|
|
with gr.Accordion("Let the Dice roll",open = False,visible=True): |
|
with gr.Row(): |
|
gr.HTML(value="<p>R:0~1, U: -0.5~1.5</p>") |
|
with gr.Row(): |
|
luckmode = gr.Radio(label = "Random Mode",choices = ["off", "R", "U", "X", "ER", "EU", "EX","custom"], value = "off") |
|
with gr.Row(): |
|
lucksets = gr.CheckboxGroup(label = "Settings",choices=["alpha","beta","save E-list"],value=["alpha"],type="value",interactive=True) |
|
with gr.Row(): |
|
luckseed = gr.Number(minimum=-1, maximum=4294967295, step=1, label='Seed for Random Ratio', value=-1, elem_id="luckseed") |
|
luckround = gr.Number(minimum=1, maximum=4294967295, step=1, label='Round', value=3, elem_id="luckround") |
|
luckserial = gr.Number(minimum=1, maximum=4294967295, step=1, label='Num of challenge', value=1, elem_id="luckchallenge") |
|
with gr.Row(): |
|
luckcustom = gr.Textbox(label="custom",value = "U,0,0,0,0,0,0,0,0,0,0,0,0,R,R,R,R,R,R,R,R,R,R,R,R,R") |
|
with gr.Row(): |
|
lucklimits_u = gr.Textbox(label="Upper limit for X",value = "1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1") |
|
with gr.Row(): |
|
lucklimits_l = gr.Textbox(label="Lower limit for X",value = "0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0") |
|
components.rand_merge = gr.Button(elem_id="runrandmerge", value="Run Rand",variant='primary') |
|
|
|
with gr.Accordion("Generation Parameters",open = False): |
|
gr.HTML(value='If blank or set to 0, parameters in the "txt2img" tab are used.<br>batch size, restore face, hires fix settigns must be set here') |
|
prompt = gr.Textbox(label="prompt",lines=1,value="") |
|
neg_prompt = gr.Textbox(label="neg_prompt",lines=1,value="") |
|
with gr.Row(): |
|
sampler = gr.Dropdown(label='Sampling method', elem_id=f"sampling", choices=[" ",*[x.name for x in samplers]], value=" ", type="index") |
|
steps = gr.Slider(minimum=0.0, maximum=150, step=1, label='Steps',value=0, elem_id="Steps") |
|
cfg = gr.Slider(minimum=0.0, maximum=30, step=0.5, label='CFG scale', value=0, elem_id="cfg") |
|
with gr.Row(): |
|
width = gr.Slider(minimum=0, maximum=2048, step=8, label="Width", value=0, elem_id="txt2img_width") |
|
height = gr.Slider(minimum=0, maximum=2048, step=8, label="Height", value=0, elem_id="txt2img_height") |
|
seed = gr.Number(minimum=-1, maximum=4294967295, step=1, label='Seed', value=0, elem_id="seed") |
|
batch_size = denois_str = gr.Slider(minimum=0, maximum=8, step=1, label='Batch size', value=1, elem_id="sm_txt2img_batch_size") |
|
genoptions = gr.CheckboxGroup(label = "Gen Options",choices=["Restore faces", "Tiling", "Hires. fix"], visible = True,interactive=True,type="value") |
|
with gr.Row(elem_id="txt2img_hires_fix_row1", variant="compact"): |
|
hrupscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) |
|
hr2ndsteps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") |
|
denois_str = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") |
|
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") |
|
with gr.Row(): |
|
setdefault = gr.Button(elem_id="setdefault", value="set to default",variant='primary') |
|
resetdefault = gr.Button(elem_id="resetdefault", value="reset default",variant='primary') |
|
resetcurrent = gr.Button(elem_id="resetcurrent", value="reset current",variant='primary') |
|
|
|
with gr.Accordion("Include/Exclude", open=False) as acc_ex: |
|
with gr.Row(): |
|
inex = gr.Radio(label="Mode", choices=["Off","Include","Exclude"], value="Off") |
|
with gr.Row(): |
|
ex_blocks = gr.CheckboxGroup(choices=EXCLUDE_CHOICES + ["print"], visible = True,interactive=True,type="value") |
|
with gr.Row(): |
|
ex_elems = gr.Textbox(label="Elements") |
|
inex.change(fn=lambda i, x,y: gr.update(label =f"{i} : " + ",".join(x) +","+ y if x != [] or y != "" else "Include/Exclude"), inputs = [inex,ex_blocks,ex_elems],outputs = [acc_ex]) |
|
ex_blocks.change(fn=lambda i, x,y: gr.update(label =f"{i} : " + ",".join(x) +","+ y if x != [] or y != "" else "Include/Exclude"), inputs = [inex,ex_blocks,ex_elems],outputs = [acc_ex]) |
|
ex_elems.change(fn=lambda i, x,y: gr.update(label =f"{i} : " + ",".join(x) +","+ y if x != [] or y != "" else "Include/Exclude"),inputs=[inex,ex_blocks,ex_elems],outputs = [acc_ex]) |
|
|
|
with gr.Accordion("Advanced", open=False): |
|
with gr.Row(): |
|
currentcache = gr.Textbox(label="Current Cache") |
|
loadcachelist = gr.Button(elem_id="model_merger_merge", value="Reload Cache List",variant='primary') |
|
unloadmodel = gr.Button(value="unload model",variant='primary') |
|
|
|
with gr.Column(variant="compact"): |
|
components.currentmodel = gr.Textbox(label="Current Model",lines=1,value="") |
|
components.submit_result = gr.Textbox(label="Message") |
|
|
|
output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples) |
|
|
|
mgallery = output_panel[0] if isinstance(output_panel, tuple) else output_panel.gallery |
|
mgeninfo = output_panel[1] if isinstance(output_panel, tuple) else output_panel.generation_info |
|
mhtmlinfo = output_panel[2] if isinstance(output_panel, tuple) else output_panel.infotext |
|
mhtmllog = output_panel[3] if isinstance(output_panel, tuple) else output_panel.html_log |
|
|
|
|
|
|
|
with gr.Tab("LoRA", elem_id="tab_lora"): |
|
pluslora.on_ui_tabs() |
|
|
|
with gr.Tab("Analysis", elem_id="tab_analysis"): |
|
with gr.Tab("Models"): |
|
with gr.Row(): |
|
an_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint A",interactive=True) |
|
create_refresh_button(an_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
an_model_b = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint B",interactive=True) |
|
create_refresh_button(an_model_b, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
with gr.Row(): |
|
an_mode = gr.Radio(label = "Analysis Mode",choices = ["ASimilarity","Block","Element","Both"], value = "ASimilarity",type = "value") |
|
an_calc = gr.Radio(label = "Block method",choices = ["Mean","Min","attn2"], value = "Mean",type = "value") |
|
an_include = gr.CheckboxGroup(label = "Include",choices = ["Textencoder(BASE)","U-Net","VAE"], value = ["Textencoder(BASE)","U-Net"],type = "value") |
|
an_settings = gr.CheckboxGroup(label = "Settings",choices=["save as txt", "save as csv"],type="value",interactive=True) |
|
with gr.Row(): |
|
run_analysis = gr.Button(value="Run Analysis",variant='primary') |
|
with gr.Row(): |
|
analysis_cosdif = gr.Dataframe(headers=["block","key","similarity[%]"],) |
|
with gr.Tab("Text Encoder"): |
|
with gr.Row(): |
|
te_smd_loadkeys = gr.Button(value="Calculate Textencoer",variant='primary') |
|
te_smd_searchkeys = gr.Button(value="Search Word(red,blue,girl,...)",variant='primary') |
|
exclude = gr.Checkbox(label="exclude non numeric,alphabet,symbol word") |
|
pickupword = gr.TextArea() |
|
encoded = gr.Dataframe() |
|
|
|
run_analysis.click(fn=calccosinedif,inputs=[an_model_a,an_model_b,an_mode,an_settings,an_include,an_calc],outputs=[analysis_cosdif]) |
|
|
|
with gr.Tab("History", elem_id="tab_history"): |
|
|
|
with gr.Row(): |
|
with gr.Column(scale = 2): |
|
with gr.Row(): |
|
count = gr.Dropdown(choices=["20", "30", "40", "50", "100"], value="20", label="Load count") |
|
load_history = gr.Button(value="Load history",variant='primary', elem_classes=["reset"]) |
|
reload_history = gr.Button(value="Reload history", elem_classes=["reset"]) |
|
with gr.Column(scale = 2): |
|
with gr.Row(): |
|
searchwrods = gr.Textbox(label="",lines=1,value="") |
|
search = gr.Button(value="search", elem_classes=["reset"]) |
|
searchmode = gr.Radio(label = "Search Mode",choices = ["or","and"], value = "or",type = "value") |
|
with gr.Row(): |
|
history = gr.Dataframe( |
|
headers=["ID","Time","Name","Weights alpha","Weights beta","Model A","Model B","Model C","alpha","beta","Mode","use MBW","custum name","save setting","use ID"], |
|
) |
|
|
|
import lora |
|
|
|
with gr.Tab("Elements", elem_id="tab_deep"): |
|
with gr.Row(): |
|
smd_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="Checkpoint",interactive=True) |
|
create_refresh_button(smd_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
smd_loadkeys = gr.Button(value="load keys",variant='primary') |
|
with gr.Row(): |
|
smd_lora = gr.Dropdown(list(lora.available_loras.keys()),elem_id="model_converter_model_name",label="LoRA",interactive=True) |
|
create_refresh_button(smd_lora, lora.list_available_loras, lambda: {"choices": list(lora.available_loras.keys())},"refresh_checkpoint_Z") |
|
smd_loadkeys_l = gr.Button(value="load keys",variant='primary') |
|
with gr.Row(): |
|
keys = gr.Dataframe(headers=["No.","block","key"],) |
|
|
|
with gr.Tab("Metadeta", elem_id="tab_metadata"): |
|
with gr.Row(): |
|
meta_model_a = gr.Dropdown(sd_models.checkpoint_tiles(),elem_id="model_converter_model_name",label="read metadata",interactive=True) |
|
create_refresh_button(meta_model_a, sd_models.list_models,lambda: {"choices": sd_models.checkpoint_tiles()},"refresh_checkpoint_Z") |
|
smd_loadmetadata = gr.Button(value="load keys",variant='primary') |
|
with gr.Row(): |
|
metadata = gr.TextArea() |
|
|
|
smd_loadmetadata.click( |
|
fn=loadmetadata, |
|
inputs=[meta_model_a], |
|
outputs=[metadata] |
|
) |
|
|
|
mclearcache.click(fn=clearcache) |
|
smd_loadkeys.click(fn=loadkeys,inputs=[smd_model_a,components.dfalse],outputs=[keys]) |
|
smd_loadkeys_l.click(fn=loadkeys,inputs=[smd_lora,components.dtrue],outputs=[keys]) |
|
|
|
te_smd_loadkeys.click(fn=encodetexts,inputs=[exclude],outputs=[encoded]) |
|
te_smd_searchkeys.click(fn=pickupencode,inputs=[pickupword],outputs=[encoded]) |
|
|
|
|
|
def unload(): |
|
if shared.sd_model == None: return "already unloaded" |
|
load_model,unload_model_weights() |
|
return "model unloaded" |
|
|
|
unloadmodel.click(fn=unload,outputs=[components.submit_result]) |
|
|
|
load_history.click(fn=load_historyf,inputs=[history,count],outputs=[history]) |
|
reload_history.click(fn=load_historyf,inputs=[history,count,components.dtrue],outputs=[history]) |
|
|
|
components.msettings=[weights_a,weights_b,model_a,model_b,model_c,base_alpha,base_beta,mode,calcmode,useblocks,custom_name,save_sets,components.id_sets,wpresets,deep,finetune,bake_in_vae,opt_value,inex,ex_blocks,ex_elems] |
|
components.imagegal = [mgallery,mgeninfo,mhtmlinfo,mhtmllog] |
|
components.xysettings=[x_type,xgrid,y_type,ygrid,z_type,zgrid,esettings] |
|
components.genparams=[prompt,neg_prompt,steps,sampler,cfg,seed,width,height,batch_size] |
|
components.hiresfix = [genoptions,hrupscaler,hr2ndsteps,denois_str,hr_scale] |
|
components.lucks = [luckmode,lucksets,lucklimits_u,lucklimits_l,luckseed,luckserial,luckcustom,luckround] |
|
|
|
setdefault.click(fn = configdealer, |
|
inputs =[*components.genparams,*components.hiresfix[1:],components.dfalse], |
|
) |
|
|
|
resetdefault.click(fn = configdealer, |
|
inputs =[*components.genparams,*components.hiresfix[1:],components.dtrue], |
|
) |
|
|
|
resetcurrent.click(fn = lambda x : [gr.update(value = x) for x in RESETVALS] ,outputs =[*components.genparams,*components.hiresfix[1:]],) |
|
|
|
s_reverse.click(fn = reversparams, |
|
inputs =mergeid, |
|
outputs = [components.submit_result,*components.msettings[0:8],*components.msettings[9:13],deep,calcmode,luckseed,finetune,opt_value,inex,ex_blocks,ex_elems] |
|
) |
|
|
|
search.click(fn = searchhistory,inputs=[searchwrods,searchmode],outputs=[history]) |
|
|
|
s_reloadreserve.click(fn=nulister,inputs=[components.dfalse],outputs=[components.numaframe]) |
|
s_delreserve.click(fn=nulister,inputs=[s_delnum],outputs=[components.numaframe]) |
|
loadcachelist.click(fn=getcachelist,inputs=[],outputs=[currentcache]) |
|
addtox.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[xgrid]) |
|
addtoy.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[ygrid]) |
|
addtoz.click(fn=lambda x:gr.Textbox.update(value = x),inputs=[inputer],outputs=[zgrid]) |
|
|
|
stopgrid.click(fn=freezetime) |
|
stopmerge.click(fn=freezemtime) |
|
|
|
checkpoints.change(fn=lambda x:",".join(x),inputs=[checkpoints],outputs=[inputer]) |
|
blockids.change(fn=lambda x:" ".join(x),inputs=[blockids],outputs=[inputer]) |
|
calcmodes.change(fn=lambda x:",".join(x),inputs=[calcmodes],outputs=[inputer]) |
|
|
|
menbers = [base,in00,in01,in02,in03,in04,in05,in06,in07,in08,in09,in10,in11,mi00,ou00,ou01,ou02,ou03,ou04,ou05,ou06,ou07,ou08,ou09,ou10,ou11] |
|
menbers_plus = menbers + [resetval] |
|
|
|
lower.change(fn = lambda x: [gr.update(minimum = x) for i in range(len(menbers_plus))],inputs = [lower],outputs = menbers_plus) |
|
upper.change(fn = lambda x: [gr.update(maximum = x) for i in range(len(menbers_plus))],inputs = [upper],outputs = menbers_plus) |
|
|
|
setalpha.click(fn=slider2text,inputs=[*menbers,wpresets, dd_preset_weight,isxl],outputs=[weights_a]) |
|
setbeta.click(fn=slider2text,inputs=[*menbers,wpresets, dd_preset_weight,isxl],outputs=[weights_b]) |
|
setx.click(fn=add_to_seq,inputs=[xgrid,weights_a],outputs=[xgrid]) |
|
sety.click(fn=add_to_seq,inputs=[ygrid,weights_b],outputs=[ygrid]) |
|
|
|
mode_info = { |
|
"Weight sum": "A*(1-alpha)+B*alpha", |
|
"Add difference": "A+(B-C)*alpha", |
|
"Triple sum": "A*(1-alpha-beta)+B*alpha+C*beta", |
|
"sum Twice": "(A*(1-alpha)+B*alpha)*(1-beta)+C*beta" |
|
} |
|
mode.change(fn=lambda mode,calcmode: [gr.update(info=mode_info[mode]), gr.update(interactive=True if mode in ["Triple sum", "sum Twice"] or calcmode in ["tensor", "tensor2"] else False)], inputs=[mode,calcmode], outputs=[mode, base_beta], show_progress=False) |
|
calcmode.change(fn=lambda calcmode: gr.update(interactive=True) if calcmode in ["tensor", "tensor2","extract"] else gr.update(), inputs=[calcmode], outputs=base_beta, show_progress=False) |
|
useblocks.change(fn=lambda mbw: gr.update(visible=False if mbw else True), inputs=[useblocks], outputs=[alpha_group]) |
|
|
|
def save_current_merge(custom_name, save_settings):
    """Save the current merge result and show savemodel's message in the UI.

    The first two savemodel arguments are passed as None (presumably meaning
    "use the model currently held by the merger" — confirm in model_util).
    """
    result_message = savemodel(None, None, custom_name, save_settings)
    return gr.update(value=result_message)
|
|
|
def addblockweights(val, blockopt, *blocks):
    """Add ``val`` to the block-weight sliders of the groups named in ``blockopt``.

    ``blocks`` are the 26 current slider values (BASE, IN00-IN11, M00,
    OUT00-OUT11). Groups tagged "BASE", "INP*", "MID" and "OUT*" in
    ``blockopt`` are shifted by ``val`` and the rest keep their value.
    The string "none" counts as 0.
    """
    offset = float(0 if val == "none" else val)
    # (first index, end index, blockopt tag) for each slider group
    groups = ((0, 1, "BASE"), (1, 13, "INP*"), (13, 14, "MID"), (14, 26, "OUT*"))
    shifted = []
    for start, stop, tag in groups:
        selected = tag in blockopt
        shifted.extend(blocks[k] + offset if selected else blocks[k] for k in range(start, stop))
    return setblockweights(shifted, blockopt)
|
|
|
def mulblockweights(val, blockopt, *blocks):
    """Multiply the block-weight sliders of the groups in ``blockopt`` by ``val``.

    ``blocks`` are the 26 current slider values (BASE, IN00-IN11, M00,
    OUT00-OUT11). Groups tagged "BASE", "INP*", "MID" and "OUT*" in
    ``blockopt`` are scaled and the rest keep their value. The string
    "none" counts as 0.
    """
    factor = float(0 if val == "none" else val)
    # (first index, end index, blockopt tag) for each slider group
    groups = ((0, 1, "BASE"), (1, 13, "INP*"), (13, 14, "MID"), (14, 26, "OUT*"))
    scaled = []
    for start, stop, tag in groups:
        selected = tag in blockopt
        scaled.extend(blocks[k] * factor if selected else blocks[k] for k in range(start, stop))
    return setblockweights(scaled, blockopt)
|
|
|
def resetblockweights(val, blockopt):
    """Set every slider in the groups selected by ``blockopt`` to ``val`` ("none" = 0)."""
    fill = float(0 if val == "none" else val)
    return setblockweights([fill for _ in range(26)], blockopt)
|
|
|
def setblockweights(vals, blockopt):
    """Build the 26 slider updates for the block groups selected in ``blockopt``.

    Sliders in a selected group ("BASE", "INP*", "MID", "OUT*") receive
    ``vals[i]``. All others get an empty ``gr.update()`` so their value is
    left as is.
    """
    groups = ((0, 1, "BASE"), (1, 13, "INP*"), (13, 14, "MID"), (14, 26, "OUT*"))
    updates = []
    for start, stop, tag in groups:
        if tag in blockopt:
            updates += [gr.update(value=vals[k]) for k in range(start, stop)]
        else:
            updates += [gr.update() for _ in range(start, stop)]
    return updates
|
|
|
def resetvalopt(opt):
    """Copy the chosen preset value into the numeric reset field ("none" -> 0.0)."""
    return gr.update(value=0.0 if opt == "none" else float(opt))
|
|
|
def finetune_update(finetune, detail1, detail2, detail3, contrast, bri, col1, col2, col3):
    """Serialize the eight fine-tune sliders into the comma-separated textbox.

    Zero values are written as "0" instead of "0.0". The textbox is only
    updated when the new text differs from its current content.
    """
    def as_text(number):
        return str(int(number)) if number == 0.0 else str(number)

    text = ",".join(as_text(v) for v in (detail1, detail2, detail3, contrast, bri, col1, col2, col3))
    if finetune == text:
        return gr.update()
    return gr.update(value=text)
|
|
|
def finetune_reader(finetune):
    """Parse the fine-tune textbox back into the eight fine-tune sliders.

    Returns one update per slider, in the order detail1, detail2, detail3,
    contrast, bri, col1, col2, col3. A field that is missing or not numeric
    leaves its slider unchanged.
    """
    # Must equal the number of outputs wired to this handler (the eight
    # ``finetunes`` sliders). The old value of 7 never updated col3 and made
    # gradio reject the return value.
    n_fields = 8
    ret = [gr.update() for _ in range(n_fields)]
    fields = [t.strip() for t in (finetune or "").split(",")][:n_fields]
    for i, field in enumerate(fields):
        try:
            ret[i] = gr.update(value=float(field))
        except ValueError:
            pass  # unparsable field: keep the slider's current value
    return ret
|
|
|
|
|
finetunes = [detail1, detail2, detail3, contrast, bri, col1, col2, col3] |
|
finetune_reset.click(fn=lambda: [gr.update(value="")]+[gr.update(value=0.0)]*8, inputs=[], outputs=[finetune, *finetunes]) |
|
finetune_read.click(fn=finetune_reader, inputs=[finetune], outputs=[*finetunes]) |
|
finetune_write.click(fn=finetune_update, inputs=[finetune, *finetunes], outputs=[finetune]) |
|
detail1.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
detail2.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
detail3.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
contrast.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
bri.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
col1.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
col2.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
col3.release(fn=finetune_update, inputs=[finetune, *finetunes], outputs=finetune, show_progress=False) |
|
|
|
savecurrent.click(fn=save_current_merge, inputs=[custom_name, save_sets], outputs=[components.submit_result]) |
|
|
|
resetopt.change(fn=resetvalopt,inputs=[resetopt],outputs=[resetval]) |
|
resetweight.click(fn=resetblockweights,inputs=[resetval,resetblockopt],outputs=menbers) |
|
addweight.click(fn=addblockweights,inputs=[resetval,resetblockopt,*menbers],outputs=menbers) |
|
mulweight.click(fn=mulblockweights,inputs=[resetval,resetblockopt,*menbers],outputs=menbers) |
|
|
|
readalpha.click(fn=text2slider,inputs=[weights_a,isxl],outputs=menbers) |
|
readbeta.click(fn=text2slider,inputs=[weights_b,isxl],outputs=menbers) |
|
|
|
dd_preset_weight.change(fn=on_change_dd_preset_weight,inputs=[wpresets, dd_preset_weight],outputs=menbers) |
|
dd_preset_weight_r.change(fn=on_change_dd_preset_weight_r,inputs=[wpresets, dd_preset_weight_r,luckab],outputs=[weights_a,weights_b]) |
|
|
|
def refresh_presets(presets, rand, ab=""):
    """Return a dropdown update whose choices are the preset names in ``presets``.

    ``rand`` limits the list to presets that contain random markers (see
    preset_name_list). ``ab`` is accepted but unused.
    """
    # NOTE(review): preset_refresh_r wires this single update to two outputs
    # (weights_a, weights_b); dd_preset_weight_r was probably intended.
    names = preset_name_list(presets, rand)
    return gr.update(choices=names)
|
|
|
preset_refresh.click(fn=refresh_presets,inputs=[wpresets,components.dfalse],outputs=[dd_preset_weight]) |
|
preset_refresh_r.click(fn=refresh_presets,inputs=[wpresets,components.dtrue],outputs=[weights_a,weights_b]) |
|
|
|
def changexl(isxl):
    """Show only the block sliders that exist for the selected model family.

    For SDXL (``isxl`` truthy), sliders whose BLOCKID entry is missing from
    BLOCKIDXLL are hidden. Otherwise all 26 sliders are shown.
    """
    if not isxl:
        return [gr.update(visible=True) for _ in range(26)]
    xl_ids = BLOCKIDXLL[:-1]
    return [gr.update(visible=block_id in xl_ids) for block_id in BLOCKID[:-1]]
|
|
|
isxl.change(fn=changexl,inputs=[isxl], outputs=menbers) |
|
|
|
x_type.change(fn=showxy,inputs=[x_type,y_type,z_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,zgrid,row_blocks,row_calcmode]) |
|
y_type.change(fn=showxy,inputs=[x_type,y_type,z_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,zgrid,row_blocks,row_calcmode]) |
|
z_type.change(fn=showxy,inputs=[x_type,y_type,z_type], outputs=[row_blockids,row_checkpoints,row_inputers,ygrid,zgrid,row_blocks,row_calcmode]) |
|
x_randseednum.change(fn=makerand,inputs=[x_randseednum],outputs=[xgrid]) |
|
|
|
import subprocess |
|
def openeditors():
    """Open the weight-preset file (``filepath``) with the OS default application.

    ``start`` is a Windows shell builtin, so the previous
    ``Popen(['start', filepath], shell=True)`` did nothing useful on Linux or
    macOS. Each platform now uses its native opener. Failures are printed
    instead of being raised into the UI callback.
    """
    import subprocess
    import sys

    try:
        if sys.platform == "win32":
            os.startfile(filepath)
        elif sys.platform == "darwin":
            subprocess.Popen(["open", filepath])
        else:
            subprocess.Popen(["xdg-open", filepath])
    except OSError as e:
        print(f"Supermerger: could not open {filepath}: {e}")
|
|
|
def reloadpresets():
    """Re-read the weight-preset file into the editor and the preset dropdown.

    Returns ``[text, dropdown_update]`` for the two wired outputs. If the file
    cannot be read, both outputs are left unchanged. The old code returned
    None here, which is not a valid result for a two-output handler.
    """
    try:
        with open(filepath) as f:
            weights_presets = f.read()
    except OSError as e:
        print(f"Supermerger: could not read {filepath}: {e}")
        return [gr.update(), gr.update()]
    choices = preset_name_list(weights_presets)
    return [weights_presets, gr.update(choices=choices)]
|
|
|
def savepresets(text):
    """Overwrite the weight-preset file (``filepath``) with the editor text."""
    with open(filepath, mode='w') as handle:
        handle.write(text)
|
|
|
s_reloadtext.click(fn=reloadpresets,inputs=[],outputs=[wpresets, dd_preset_weight]) |
|
s_reloadtags.click(fn=tagdicter,inputs=[wpresets],outputs=[weightstags]) |
|
s_savetext.click(fn=savepresets,inputs=[wpresets],outputs=[]) |
|
s_openeditor.click(fn=openeditors,inputs=[],outputs=[]) |
|
|
|
def savexyzpreset_f(xtype, xvals, ytype, yvals, ztype, zvals, name, mode_overwrite):
    """Store the current X/Y/Z grid settings under ``name`` in the XYZ preset file.

    An existing preset is replaced only when ``mode_overwrite`` is set.
    Otherwise the user is told the name is taken. The file is rewritten in
    either case, and the refreshed, sorted names are returned as dropdown
    choices.
    """
    entry = {"xtype": TYPESEG[xtype], "xvalues": xvals,
             "ytype": TYPESEG[ytype], "yvalues": yvals,
             "ztype": TYPESEG[ztype], "zvalues": zvals,
             }
    presets = get_xyzpreset_data()
    if name in presets and not mode_overwrite:
        gr.Info(f"Supermerger: Preset {name} already exists.")
    else:
        presets[name] = entry

    with open(xyzpath, 'w') as file:
        json.dump(presets, file, indent=4)
    return gr.update(choices=sorted(presets.keys()))
|
|
|
def deletexyzpreset_f(name):
    """Remove preset ``name`` from the XYZ preset file and refresh the dropdown choices."""
    presets = get_xyzpreset_data()
    if name in presets:
        del presets[name]
    else:
        gr.Info(f"Supermerger: Preset {name} not found.")

    with open(xyzpath, 'w') as file:
        json.dump(presets, file, indent=4)
    return gr.update(choices=sorted(presets.keys()))
|
|
|
def loadxyzpreset_f(name):
    """Return updates for (x_type, xgrid, y_type, ygrid, z_type, zgrid) from preset ``name``.

    A missing or empty preset resets the controls to "alpha"/"none" with
    empty grids.
    """
    preset = get_xyzpreset_data().get(name)
    if not preset:
        gr.Info(f"Supermerger: Preset {name} not found.")
        values = ["alpha", "", "none", "", "none", ""]
    else:
        fields = ("xtype", "xvalues", "ytype", "yvalues", "ztype", "zvalues")
        values = [preset.get(field) for field in fields]
    return [gr.update(value=v) for v in values]
|
|
|
def refreshxyzpresets_f():
    """Re-read the XYZ preset file and offer its sorted names as dropdown choices."""
    names = get_xyzpreset_keylist()
    return gr.update(choices=names)
|
|
|
savexyzpreset_b.click(fn=savexyzpreset_f,inputs=[x_type, xgrid, y_type, ygrid, z_type, zgrid,xyzpresets,savexyzpreset_overwrite],outputs=[xyzpresets]) |
|
loadxyzpreset_b.click(fn=loadxyzpreset_f,inputs=[xyzpresets],outputs=[x_type, xgrid, y_type, ygrid, z_type, zgrid]) |
|
deletexyzpreset_b.click(fn=deletexyzpreset_f,inputs=[xyzpresets],outputs=[xyzpresets]) |
|
refreshxyzpresets_b.click(fn=refreshxyzpresets_f,outputs=[xyzpresets]) |
|
|
|
return (supermergerui, "SuperMerger", "supermerger"), |
|
|
|
# Merge-history cache shared by load_historyf and searchhistory.
# msearch[i] is the space-joined text of history row mlist[i] and is used for
# substring search. mlist holds the history rows returned to the UI.
msearch = []
mlist=[]
|
|
|
def loadmetadata(model):
    """Return the safetensors metadata of checkpoint ``model`` as indented JSON.

    Returns a short message instead when the checkpoint cannot be resolved,
    is not a ``.safetensors`` file, or has no metadata.
    """
    checkpoint_info = sd_models.get_closet_checkpoint_match(model)
    # The lookup can come back empty (e.g. a cleared dropdown); guard before
    # touching .filename.
    if checkpoint_info is None:
        return "no metadata(model not found)"
    if ".safetensors" not in checkpoint_info.filename:
        return "no metadata(not safetensors)"
    sdict = sd_models.read_metadata_from_safetensors(checkpoint_info.filename)
    if not sdict:
        return "no metadata"
    return json.dumps(sdict, indent=4)
|
|
|
def load_historyf(data, count=20, reload=False):
    """Load one page of merge history from ``mergehistory.csv``, newest first.

    On a fresh load (``reload`` set, or ``data`` has at most one row), the
    ``mlist``/``msearch`` caches are reset and the newest ``count`` rows are
    read. Otherwise the ``count`` rows before the oldest ID in ``data`` are
    appended. Returns the accumulated rows. If there is nothing older to add,
    the current rows are returned unchanged. If the history cannot be read, a
    single placeholder row is returned.
    """
    global mlist, msearch
    filepath = os.path.join(path_root, "mergehistory.csv")
    try:
        count = int(count)
        with open(filepath, 'r') as f:
            reader = csv.reader(f)
            next(reader)  # header
            row_count = sum(1 for _ in reader)

            nth = None
            if not reload and data is not None and len(data) > 1:
                old = data.loc[len(data) - 1, 'ID']
                if old != '':
                    # rows strictly before the oldest ID currently shown
                    nth = int(old) - count - 1

            if nth is None:
                msearch = []
                mlist = []
                nth = row_count - count

            f.seek(0)
            next(reader)  # header again
            nlist = [row for n, row in enumerate(reader, start=1) if nth < n <= nth + count]

        if not nlist:
            # Nothing older to show: keep the current table rather than
            # replacing it with the placeholder.
            return mlist if mlist else [["no data", "", ""], ]

        nlist.reverse()
        # Zero-pad IDs to the widest one on this page so they line up and
        # sort as text. Compute the width before mutating the caches so
        # msearch and mlist stay index-aligned.
        width = max(len(row[0]) for row in nlist)
        msearch.extend(" ".join(row) for row in nlist)
        for row in nlist:
            row[0] = row[0].zfill(width)
        mlist += nlist
        return mlist
    except Exception as e:
        print(f"Supermerger: could not load merge history: {e}")
        return [["no data", "", ""], ]
|
|
|
def searchhistory(words, searchmode):
    """Filter the cached merge history by space-separated search words.

    In a mode containing "and", a row must contain every word. Otherwise any
    single word is enough. Matching is a substring test against ``msearch``.
    Returns the matching ``mlist`` rows, or a "no result" placeholder row.
    """
    terms = words.split(" ")
    matches = all if "and" in searchmode else any
    hits = [mlist[pos] for pos, text in enumerate(msearch)
            if matches(term in text for term in terms)]
    return hits if hits else [["no result", "", ""], ]
|
|
|
|
|
|
|
# Number of merge-setting values that reversparams restores from a history row
# (it returns one status update followed by MSETSNUM setting updates).
MSETSNUM = 20
|
|
|
def reversparams(id):
    """Restore the merge settings recorded in merge-history entry ``id``.

    Returns ``[status_update, *setting_updates]`` with MSETSNUM setting
    updates, taken from the history columns after the first three and
    converted to what the UI widgets expect. If the entry cannot be read,
    only the status text changes.
    """
    def selectfromhash(hash):
        # First checkpoint title containing the stored hash, or "".
        for model in sd_models.checkpoint_tiles():
            if hash in model:
                return model
        return ""

    def cutter(text):
        # Parse a stringified list such as "['a', 'b']" into ["a", "b"].
        text = text.replace("[", "").replace("]", "").replace("'", "")
        return [x.strip() for x in text.split(",") if x != ""]

    def status_only(message):
        return [gr.update(value = message), *[gr.update() for x in range(MSETSNUM)]]

    try:
        idsets = rwmergelog(id = id)
    except Exception:
        return status_only("ERROR: history file could not open")
    if isinstance(idsets, str):
        print("ERROR")
        return status_only(idsets)
    if idsets[0] == "ID":
        return status_only("ERROR: no history")

    mgs = list(idsets[3:])
    # Rows written by older versions have fewer columns. Pad them before any
    # indexing below; previously indices 7-11 were read before padding and a
    # short row raised IndexError.
    while len(mgs) < MSETSNUM:
        mgs.append("")

    if mgs[0] == "":mgs[0] = "0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5"
    if mgs[1] == "":mgs[1] = "0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2,0.2"
    # The model columns appear to hold hashes; map them back to checkpoint titles.
    mgs[2] = selectfromhash(mgs[2]) if len(mgs[2]) > 5 else ""
    mgs[3] = selectfromhash(mgs[3]) if len(mgs[3]) > 5 else ""
    mgs[4] = selectfromhash(mgs[4]) if len(mgs[4]) > 5 else ""
    mgs[7] = mgs[7].split(":")[0]
    mgs[8] = mgs[8] == "True"
    mgs[10] = cutter(mgs[10])
    mgs[11] = cutter(mgs[11])
    mgs[13] = "normal" if mgs[13] == "" else mgs[13]
    mgs[14] = -1 if mgs[14] == "" else mgs[14]
    mgs[16] = 0.3 if mgs[16] == "" else float(mgs[16])
    mgs[17] = "Off" if mgs[17] == "" else mgs[17]
    mgs[18] = [x for x in cutter(mgs[18]) if x in EXCLUDE_CHOICES + ["print"]]
    return [gr.update(value = "setting loaded"), *[gr.update(value = x) for x in mgs[0:MSETSNUM]]]
|
|
|
def add_to_seq(seq, maker):
    """Append ``maker`` to grid textbox ``seq`` as a new CRLF-separated line."""
    if seq == "":
        combined = maker
    else:
        combined = seq + "\r\n" + maker
    return gr.Textbox.update(value=combined)
|
|
|
def load_cachelist():
    """Return the model names of all cached checkpoints, one per CRLF-separated line."""
    return "\r\n".join(info.model_name for info in checkpoints_loaded.keys())
|
|
|
def makerand(num):
    """Return ``int(num)`` comma-separated "-1" entries for the X grid.

    -1 presumably means "random seed". A count of zero or less yields "".
    """
    return ",".join("-1" for _ in range(int(num)))
|
|
|
|
|
def showxy(x, y, z):
    """Compute which XYZ-plot helper rows are visible for the chosen axis types.

    Returns seven visibility updates, in the order row_blockids,
    row_checkpoints, row_inputers, ygrid, zgrid, row_blocks, row_calcmode.
    """
    kinds = TYPESEG[x] + TYPESEG[y] + TYPESEG[z]
    show_blockids = "pinpoint" in kinds
    show_checkpoints = "model" in kinds
    show_ygrid = "none" not in TYPESEG[y]
    show_zgrid = "none" not in TYPESEG[z]
    # The shared input row is needed by any of the helpers above.
    show_inputers = show_checkpoints or show_blockids or show_ygrid or show_zgrid
    show_blocks = "clude" in kinds  # matches include/exclude axis types
    show_calcmode = "calcmode" in kinds
    flags = [show_blockids, show_checkpoints, show_inputers,
             show_ygrid, show_zgrid, show_blocks, show_calcmode]
    return [gr.update(visible=flag) for flag in flags]
|
|
|
def get_xyzpreset_data():
    """Load the saved XYZ-plot presets from ``xyzpath`` as a dict.

    If the file does not exist, an empty one is created. If the file is not
    valid JSON, a warning is printed and ``{}`` is returned instead of raising
    inside the UI callback. This function leaves the file untouched, but a
    later savexyzpreset_f/deletexyzpreset_f will overwrite it.
    """
    try:
        with open(xyzpath, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        with open(xyzpath, 'w') as file:
            json.dump({}, file, indent=4)
        return {}
    except json.JSONDecodeError as e:
        print(f"Supermerger: {xyzpath} is not valid JSON ({e}); ignoring it.")
        return {}
|
|
|
def get_xyzpreset_keylist():
    """Return the names of the saved XYZ presets, sorted."""
    return sorted(get_xyzpreset_data().keys())
|
|
|
def text2slider(text, isxl=False):
    """Turn a comma-separated weight string into block-slider updates.

    Each field is stripped. Fields that are a substring of "RUX" (the random
    markers, and also the empty string) count as 0. For SDXL the fields fill,
    in order, only the slots flagged True in ISXLBLOCK, and the other sliders
    are left unchanged. Otherwise field i updates slider i.
    """
    fields = []
    for raw in text.split(","):
        field = raw.strip()
        fields.append(0 if field in "RUX" else field)

    if not isxl:
        return [gr.update(value=float(field)) for field in fields]

    updates = [gr.update() for _ in ISXLBLOCK]
    xl_slots = [pos for pos, used in enumerate(ISXLBLOCK) if used]
    for j, pos in enumerate(xl_slots):
        updates[pos] = gr.update(value=float(fields[j]))
    return updates
|
|
|
def slider2text(a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z,presets, preset, isxl):
    """Serialize the 26 block-weight sliders into a comma-separated string.

    If the selected preset contains a random marker (RANCHA), its raw text is
    returned instead so the markers are preserved. For SDXL, only sliders
    whose BLOCKID entry is also in BLOCKIDXLL are written.
    """
    preset_text = find_preset_by_name(presets, preset)
    if preset_text is not None and any(marker in preset_text for marker in RANCHA):
        return preset_text

    sliders = [a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z]
    if isxl:
        xl_ids = BLOCKIDXLL[:-1]
        sliders = [val for block_id, val in zip(BLOCKID[:-1], sliders) if block_id in xl_ids]
    return gr.update(value=",".join(str(val) for val in sliders))
|
|
|
def on_change_dd_preset_weight(presets, preset):
    """Load the weights of the selected preset into the 26 block sliders.

    This handler is wired to 26 outputs (``menbers``), so it must always
    return 26 values. When the preset cannot be found (e.g. the dropdown was
    cleared), every slider is left unchanged instead of returning None.
    """
    weights = find_preset_by_name(presets, preset)
    if weights is None:
        return [gr.update() for _ in range(26)]
    return text2slider(weights)
|
|
|
def on_change_dd_preset_weight_r(presets, preset, ab):
    """Copy the selected random preset into the alpha or beta weight textbox.

    ``ab`` selects the target ("alpha", "beta" or "none"). Always returns a
    pair of updates for (weights_a, weights_b). An unknown preset or an
    unmatched ``ab`` leaves both textboxes unchanged.
    """
    weights = find_preset_by_name(presets, preset)
    if weights is None or "none" in ab:
        return gr.update(), gr.update()
    if "alpha" in ab:
        return gr.update(value = weights), gr.update()
    if "beta" in ab:
        return gr.update(), gr.update(value = weights)
    return gr.update(), gr.update()
|
|
|
# Markers that may appear in a weight string instead of a number. Presets
# containing any of them are treated as "random" presets: tagdicter lists them
# when rand is set, and slider2text returns their raw text.
RANCHA = ["R","U","X"]
|
|
|
def tagdicter(presets, rand = False):
    """Return the comma-joined names of the valid presets in ``presets``.

    Each line is "NAME:weights" or "NAME<TAB>weights"; a tab takes priority
    when both appear. Only entries with exactly 26 comma-separated weights
    are kept, and with ``rand`` only those that contain a RANCHA marker.
    Names are stripped. A repeated name keeps its first position but takes
    the later weights.
    """
    found = {}
    for line in presets.splitlines():
        if "\t" in line:
            name, _, weights = line.partition("\t")
        elif ":" in line:
            name, _, weights = line.partition(":")
        else:
            continue
        if len(weights.split(",")) != 26:
            continue
        if rand and not any(marker in weights for marker in RANCHA):
            continue
        found[name.strip()] = weights
    return ",".join(found)
|
|
|
def preset_name_list(presets, rand = False):
    """Return the preset names parsed by tagdicter as a list (``[""]`` if none)."""
    joined = tagdicter(presets, rand)
    return joined.split(",")
|
|
|
def find_preset_by_name(presets, preset):
    """Return the weight string of preset ``preset``, or None if absent.

    Parsing matches ``tagdicter``, which builds the dropdown choices: a tab
    separator takes priority over ":", whitespace around the name is ignored,
    and only entries with exactly 26 comma-separated weights count. Before
    this change ":" took priority and names were compared unstripped, so some
    names shown in the dropdown could not be found.
    """
    if not presets:
        return None
    target = preset.strip() if isinstance(preset, str) else preset
    for line in presets.splitlines():
        if "\t" in line:
            name, _, weights = line.partition("\t")
        elif ":" in line:
            name, _, weights = line.partition(":")
        else:
            continue
        if name.strip() == target and len(weights.split(",")) == 26:
            return weights
    return None
|
|
|
# SD1.x/2.x block identifiers: BASE, IN00-IN11, the middle block M00 and
# OUT00-OUT11 (the 26 weight slots), plus a trailing "Not Merge" entry that
# callers drop with BLOCKID[:-1].
BLOCKID=["BASE","IN00","IN01","IN02","IN03","IN04","IN05","IN06","IN07","IN08","IN09","IN10","IN11","M00","OUT00","OUT01","OUT02","OUT03","OUT04","OUT05","OUT06","OUT07","OUT08","OUT09","OUT10","OUT11","Not Merge"]

# SDXL block identifiers in short (unpadded) form, ending with "VAE".
BLOCKIDXL=['BASE', 'IN0', 'IN1', 'IN2', 'IN3', 'IN4', 'IN5', 'IN6', 'IN7', 'IN8', 'M', 'OUT0', 'OUT1', 'OUT2', 'OUT3', 'OUT4', 'OUT5', 'OUT6', 'OUT7', 'OUT8', 'VAE']

# SDXL block identifiers using the same zero-padded names as BLOCKID, ending
# with "VAE". The first 20 entries are the SDXL weight slots.
BLOCKIDXLL=['BASE', 'IN00', 'IN01', 'IN02', 'IN03', 'IN04', 'IN05', 'IN06', 'IN07', 'IN08', 'M00', 'OUT00', 'OUT01', 'OUT02', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'VAE']

# One flag per entry of BLOCKID[:-1]: True when that block also exists in
# BLOCKIDXLL. text2slider uses it to place the 20 SDXL weights on the 26 sliders.
ISXLBLOCK=[True, True, True, True, True, True, True, True, True, True, False, False, False, True, True, True, True, True, True, True, True, True, True, False, False, False]
|
|
|
def modeltype(sd):
    """Classify a state dict as "XL" or "1.X or 2.X".

    A model counts as SDXL when it has a weight under
    ``conditioner.embedders.1`` (presumably SDXL's second text encoder).
    """
    xl_marker = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight"
    return "XL" if xl_marker in sd.keys() else "1.X or 2.X"
|
|
|
def loadkeys(model_a, lora):
    """List the keys of a checkpoint, or of a LoRA file when ``lora`` is set.

    Each row is ``[index, block, key, shape]``. For LoRA files the block column
    is the constant "LoRA"; otherwise it comes from ``blockfromkey``.
    """
    if lora:
        # webui's Lora module (presumably the built-in extension), imported
        # under a separate name so it does not shadow the ``lora`` flag.
        import lora as lora_module
        sd = sd_models.read_state_dict(lora_module.available_loras[model_a].filename, "cpu")
    else:
        sd = loadmodel(model_a)

    mtype = modeltype(sd)
    return [
        [index, "LoRA" if lora else blockfromkey(key, mtype), key, sd[key].shape]
        for index, key in enumerate(sd.keys())
    ]
|
|
|
def loadmodel(model):
    """Resolve checkpoint title ``model`` and read its state dict on the CPU."""
    info = sd_models.get_closet_checkpoint_match(model)
    return sd_models.read_state_dict(info.filename, "cpu")
|
|
|
# Extra preset lines whose 26 weights are all one random marker (R, U or X).
# NOTE(review): tagdicter/find_preset_by_name only split "NAME<sep>weights" on
# ":" or a tab; confirm the separator after ALL_R/ALL_U/ALL_X is a tab.
ADDRAND = "\n\
ALL_R R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R,R\n\
ALL_U U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U,U\n\
ALL_X X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X,X\n\
"
|
|
|
def calccosinedif(model_a,model_b,mode,settings,include,calc):
    """Compare two checkpoints by the cosine similarity of their weights.

    Args:
        model_a, model_b: checkpoint titles, loaded with ``loadmodel``.
        mode: containing "ASim" runs ``asimilarity``. "Block" returns one value
            per block. Any other mode returns per-key rows, preceded by
            per-block means unless the mode is "Element".
        settings: output options; "txt"/"csv" also save the result.
        include: parts to compare ("VAE", "U-Net", "encoder").
        calc: per-block statistic in "Block" mode: "Mean", "Min", or otherwise
            the value of the block's ``attn2.to_out.0.weight`` key.

    Similarities are percentages rounded to 3 decimals. Both state dicts are
    released when the function finishes, even on error.
    """
    inc = " ".join(include)
    settings = " ".join(settings)
    a, b = loadmodel(model_a), loadmodel(model_b)
    name = filenamecutter(model_a) + "-" + filenamecutter(model_b)
    isxl = "XL" == modeltype(a)
    blockids = BLOCKIDXLL if isxl else BLOCKID

    try:
        if "ASim" in mode:
            result = asimilarity(a, b, isxl)
            if len(settings) > 1: savecalc(result, name, settings, True, "Asim")
            return result

        cosine_similarities = []
        blocksim = {bl: [] for bl in blockids}
        blocksim["VAE"] = []
        attn2 = {}

        for key in tqdm(a.keys(), desc="Calculating cosine similarity"):
            keyblock = blockfromkey(key, isxl)
            if keyblock == "Not Merge": continue
            if "model_ema" in key: continue
            if "model" not in key: continue
            block = blocks26 = None
            if "first_stage_model" in key:
                if "VAE" not in inc: continue
                # VAE keys get their own bucket for both the row label and the
                # per-block statistics (blocks26 used to be left stale here).
                block = blocks26 = "VAE"
            if "diffusion_model" in key and "U-Net" not in inc: continue
            if "encoder" in key and "encoder" not in inc: continue
            if key not in b or a[key].size() != b[key].size(): continue

            a_flat = a[key].reshape(-1).to(torch.float32)
            b_flat = b[key].reshape(-1).to(torch.float32)
            simab = torch.nn.functional.cosine_similarity(a_flat.unsqueeze(0), b_flat.unsqueeze(0))
            if block is None: block, blocks26 = keyblock
            if block == "Not Merge": continue
            sim = round(simab.item() * 100, 3)
            cosine_similarities.append([block, key, sim])
            blocksim[blocks26].append(sim)
            if "attn2.to_out.0.weight" in key: attn2[block] = sim

        blockvals = []
        for bl in blockids:
            if bl == "Not Merge" or not blocksim.get(bl):
                continue
            if "Mean" in calc:
                val = mean(blocksim[bl])
            elif "Min" in calc:
                val = min(blocksim[bl])
            else:
                val = attn2.get(bl)
            # `is not None` so a similarity of exactly 0 is still reported.
            if val is not None: blockvals.append([bl, "", round(val, 3)])
            if mode != "Element": cosine_similarities.insert(0, [bl, "", round(mean(blocksim[bl]), 3)])

        if mode == "Block":
            if len(settings) > 1: savecalc(blockvals, name, settings, True, "Blocks")
            return blockvals

        if len(settings) > 1: savecalc(cosine_similarities, name, settings, False, "Elements")
        return cosine_similarities
    finally:
        del a, b
        gc.collect()
|
|
|
def savecalc(data,name,settings,blocks,add):
    """Write similarity rows to ``<path_root>/<name>_<add>.txt`` and/or ``.csv``.

    Each row becomes one comma-joined line. For block results (``blocks``),
    the empty middle column is collapsed by replacing ",," with ",". A file
    is written for each of "txt"/"csv" found in ``settings``.
    """
    stem = name + "_" + add
    text = "".join(",".join(str(cell) for cell in row) + "\n" for row in data)
    if blocks:
        text = text.replace(",,", ",")

    for ext in ("txt", "csv"):
        if ext not in settings:
            continue
        target = os.path.join(path_root, f"{stem}.{ext}")
        with open(target, 'w+') as f:
            f.write(text)
        print("file saved to ", target)
|
|
|
|
|
|
|
def cal_cross_attn(to_q, to_k, to_v, rand_input):
    """Apply one attention layer's q/k/v weights to ``rand_input`` and combine them.

    Each weight is loaded into a bias-free ``nn.Linear`` layer. Building the
    layers draws their (discarded) initial weights from the global RNG, and
    asimilarity's seeded random inputs depend on that sequence, so the layers
    are kept rather than replaced by a plain matmul.

    NOTE(review): the final einsum "ik, jk -> ik" multiplies each softmax
    entry by the column sum of V instead of computing softmax @ V
    ("ij, jk -> ik"). It is kept as-is so similarity values stay comparable.
    """
    hidden_dim, embed_dim = to_q.shape

    def project(weight):
        layer = nn.Linear(hidden_dim, embed_dim, bias=False)
        layer.load_state_dict({"weight": weight})
        return layer(rand_input)

    query, key, value = project(to_q), project(to_k), project(to_v)
    scores = torch.einsum("ij, kj -> ik", query, key)
    attention = F.softmax(scores, dim=-1)
    return torch.einsum("ik, jk -> ik", attention, value)
|
|
|
def eval(model, n, input, block):
    """Run ``cal_cross_attn`` on one UNet block's first self-attention (attn1) weights.

    ``block`` is "input", "middle" or "output". ``n`` is the already formatted
    index suffix (e.g. "s.4", or "" for the middle block) built by
    asimilarity. Note that this name shadows the builtin ``eval``.
    """
    prefix = f"model.diffusion_model.{block}_block{n}.1.transformer_blocks.0.attn1"
    to_q, to_k, to_v = (model[f"{prefix}.{proj}.weight"] for proj in ("to_q", "to_k", "to_v"))
    return cal_cross_attn(to_q, to_k, to_v, input)
|
|
|
# (index, block) pairs naming the UNet blocks whose first transformer's
# self-attention (attn1) asimilarity compares. The index "" marks the middle
# block, whose key has no numeric suffix ("middle_block.1...").
ATTN1BLOCKS = [[1,"input"],[2,"input"],[4,"input"],[5,"input"],[7,"input"],[8,"input"],["","middle"],
[3,"output"],[4,"output"],[5,"output"],[6,"output"],[7,"output"],[8,"output"],[9,"output"],[10,"output"],[11,"output"]]
|
|
|
def asimilarity(model_a,model_b,mtype):
    """Compare the self-attention (attn1) outputs of two UNets on shared random inputs.

    For each block in ATTN1BLOCKS, both models' q/k/v weights are applied to
    the same seeded random input. The mean cosine similarity is reported as
    ``[block, "", percent]``. The RNG is forked so the fixed seed does not
    leak into the global torch RNG state used by the rest of webui; the
    numbers produced are the same as before. Blocks whose attn1 weights are
    missing from either model (presumably some blocks of non-SD1.x
    architectures) are skipped instead of raising KeyError.
    """
    sims = []
    with torch.random.fork_rng():
        torch.manual_seed(2244096)
        for n, block in tqdm(ATTN1BLOCKS, desc="Calculating cosine similarity"):
            if n != "": n = f"s.{n}"
            key = f"model.diffusion_model.{block}_block{n}.1.transformer_blocks.0.attn1.to_q.weight"
            if key not in model_a or key not in model_b:
                continue

            hidden_dim, embed_dim = model_a[key].shape
            rand_input = torch.randn([embed_dim, hidden_dim])

            attn_a = eval(model_a, n, rand_input, block)
            attn_b = eval(model_b, n, rand_input, block)

            sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
            sims.append([blockfromkey(key, mtype), "", round(sim.item() * 100, 3)])

    return sims
|
|
|
# Generation settings persisted by configdealer. CONFIGS[i] names the
# ui-config.json entry "supermerger/<name>/value", and RESETVALS[i] is the
# value written for that entry on reset.
CONFIGS = ["prompt","neg_prompt","Steps","Sampling method","CFG scale","Seed","Width","Height","Batch size","Upscaler","Hires steps","Denoising strength","Upscale by"]
RESETVALS = ["","",0," ",0,0,0,0,1,"Latent",0,0.7,2]
|
|
|
def configdealer(prompt,neg_prompt,steps,sampler,cfg,seed,width,height,batch_size,
                 hrupscaler,hr2ndsteps,denois_str,hr_scale,reset):
    """Persist or reset SuperMerger's generation settings in ui-config.json.

    Each value is stored under ``supermerger/<CONFIGS name>/value`` in the
    ui-config.json of the current working directory. When ``reset`` is truthy,
    the defaults from RESETVALS are written instead of the given values.
    """
    values = [prompt,neg_prompt,steps,sampler,cfg,seed,width,height,batch_size,
              hrupscaler,hr2ndsteps,denois_str,hr_scale]
    jsonpath = os.path.join(os.getcwd(), "ui-config.json")
    print(jsonpath)

    with open(jsonpath, 'r') as file:
        config = json.load(file)

    config.update(
        (f"supermerger/{label}/value", default if reset else value)
        for label, value, default in zip(CONFIGS, values, RESETVALS)
    )

    with open(jsonpath, 'w') as file:
        json.dump(config, file, indent=4)
|
|
|
# Full token ranking from the last encodetexts() run. Rows are
# [rank, word, token_id, embedding_norm]; pickupencode searches it.
sorted_output = []
|
|
|
def encodetexts(exclude):
    """Rank every tokenizer vocabulary entry by the norm of its text embedding.

    Each token is encoded alone as ``[start, token, end] + [end] * 74``
    (77 ids), and the norm of the hidden state at position 1 is recorded.
    Rows ``[rank, word, token_id, norm]`` are sorted by norm, highest first.
    The full list is stored in the module-level ``sorted_output`` and the
    top 1000 rows are returned. With ``exclude`` set, words that have no ASCII
    letter, digit or common punctuation (see has_alphanumeric) are dropped.
    """
    global sorted_output
    from modules import devices  # webui's active torch device

    isxl = hasattr(shared.sd_model, "conditioner")
    model = shared.sd_model.conditioner.embedders[0] if isxl else shared.sd_model.cond_stage_model
    encoder = model.encode_with_transformers
    tokenizer = model.tokenizer
    byte_decoder = tokenizer.byte_decoder

    # Materialize the vocabulary once instead of once per batch.
    vocab_items = list(tokenizer.get_vocab().items())
    batch_size = 500
    output = []

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for start in tqdm(range(0, len(vocab_items), batch_size)):
            texts = vocab_items[start:start + batch_size]
            batch = [[model.id_start, idx, model.id_end] + [model.id_end] * 74 for _, idx in texts]
            # Use the device webui runs the model on instead of hard-coding "cuda".
            embedding = encoder(torch.IntTensor(batch).to(devices.device))[:, 1, :]
            norms = torch.linalg.vector_norm(embedding, dim=-1).tolist()

            for norm, (word, token) in zip(norms, texts):
                try:
                    word = bytearray([byte_decoder[x] for x in word]).decode("utf-8")
                except UnicodeDecodeError:
                    pass
                if exclude and not has_alphanumeric(word):
                    continue
                output.append([word, token, norm])

    output.sort(key=lambda row: row[2], reverse=True)
    for rank, row in enumerate(output):
        row.insert(0, rank)

    sorted_output = output
    return output[:1000]
|
|
|
def pickupencode(texts):
    """Look up comma-separated words in the last ``encodetexts`` ranking.

    Each word is stripped of surrounding whitespace (so "cat, dog" finds
    "dog") and matched both as-is and with the end-of-word suffix "</w>".
    Empty entries are ignored. Returns the matching
    ``[rank, word, token_id, norm]`` rows in input order.
    """
    # word -> first row carrying that word (same result as list.index), built
    # once instead of scanning the whole ranking for every lookup.
    by_word = {}
    for row in sorted_output:
        by_word.setdefault(row[1], row)

    output = []
    for text in texts.split(","):
        text = text.strip()
        if not text:
            continue
        for candidate in (text, text + "</w>"):
            if candidate in by_word:
                output.append(by_word[candidate])
    return output
|
|
|
# ASCII letters, digits and common punctuation. Compiled once because
# has_alphanumeric runs for every vocabulary entry in encodetexts.
_ASCII_WORDISH = re.compile(r'[a-zA-Z0-9!@#$%^&*()_+{}\[\]:;"\'<>,.?/\|\\]')


def has_alphanumeric(text):
    """Return True if ``text``, minus "</w>" markers, has an ASCII letter, digit or common punctuation."""
    stripped = text.replace("</w>", "")
    return _ASCII_WORDISH.search(stripped) is not None
|
|
|
# Register the SuperMerger tab builder (on_ui_tabs) with webui, but only when
# this file is loaded as part of the "supermerger" package.
if __package__ == "supermerger":
    script_callbacks.on_ui_tabs(on_ui_tabs)
|
|