Regarding your implementation
Hello! Thanks for trying out our model and fine-tuning/distilling on it! We're planning to release our code once our technical paper is out. Could you share more about exactly what your loss strategy was, and your experience with the outputs?
Also, nice idea with the cosine similarities! May I ask what you compared and how similar the results were?
I attempted a merge based on the following formula:
NewModel = SSD-1B + w * (FINETUNEDSDXL - SDXL)
However, the layer correspondence between the two models was not known, especially for the layers with reduced transformer_depth, so I matched layers by the cosine similarity of their weights:
import torch
from torch.nn.functional import cosine_similarity
import numpy as np

# Load the diffusers UNet state dicts.
sdxl = torch.load("sdxl.pt")
ssd = torch.load("ssd.pt")

# Prefixes of every attention layer in each model.
sdxl_keys = [key.replace("to_q.weight", "") for key in sdxl.keys() if "to_q.weight" in key]
ssd_keys = [key.replace("to_q.weight", "") for key in ssd.keys() if "to_q.weight" in key]

ssd2sdxl = {
    "to_q": {},
    "to_k": {},
    "to_v": {},
    "to_out.0": {},
}

# For each SSD-1B attention layer, pick the SDXL layer whose weight
# matrix is most cosine-similar (only same-shape candidates qualify).
for to_x in ssd2sdxl.keys():
    for ssd_key in ssd_keys:
        sims = []
        target = ssd[ssd_key + to_x + ".weight"]
        for sdxl_key in sdxl_keys:
            if target.shape == sdxl[sdxl_key + to_x + ".weight"].shape:
                sims.append(cosine_similarity(target.view(1, -1),
                                              sdxl[sdxl_key + to_x + ".weight"].view(1, -1)).item())
            else:
                sims.append(-100)  # shape mismatch: rule out this candidate
        ssd2sdxl[to_x][ssd_key] = sdxl_keys[np.array(sims).argmax()]

# All four projections agree on the same mapping.
print(ssd2sdxl["to_q"] == ssd2sdxl["to_k"] == ssd2sdxl["to_v"] == ssd2sdxl["to_out.0"])  # True
print(ssd2sdxl["to_q"])
The output is here. The results for up_blocks.0.attentions.2 were odd, so I changed them manually.
Since w=1 had little effect and w=1.5 produced a coarser image, I settled on w=1.3.
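A minimal sketch of how the merge formula can be applied with this mapping (simplified; ft_sdxl and the file names are illustrative, and non-attention layers are assumed to merge by matching names, which is not necessarily the exact script):

```python
# Apply NewModel = SSD-1B + w * (FINETUNEDSDXL - SDXL) using the mapping above.
ft_sdxl = torch.load("ft_sdxl.pt")  # fine-tuned SDXL state dict (illustrative name)
w = 1.3

merged = {}
for key, value in ssd.items():
    matched = None
    # Attention layers: use the cosine-similarity mapping.
    for to_x, mapping in ssd2sdxl.items():
        suffix = to_x + ".weight"
        if key.endswith(suffix) and key[:-len(suffix)] in mapping:
            matched = mapping[key[:-len(suffix)]] + suffix
            break
    # Everything else: assume layers that share a name correspond directly.
    if matched is None and key in sdxl:
        matched = key
    if matched is not None and sdxl[matched].shape == value.shape:
        merged[key] = value + w * (ft_sdxl[matched] - sdxl[matched])
    else:
        merged[key] = value  # no correspondence found: keep the SSD-1B weight

torch.save(merged, "merged.pt")
```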
To further improve accuracy, I then distilled the merged model against the original model. The only loss is the squared error on the final output.
The dataset consists of 30,000 real images.
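A rough sketch of what such an output-matching distillation loop could look like with diffusers-style UNets (teacher_unet as the original model, student_unet as the merged model; the dataloader and hyperparameters are illustrative assumptions, not the exact setup):

```python
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(student_unet.parameters(), lr=1e-5)

# dataloader yields VAE latents of the 30k images plus SDXL conditioning.
for latents, text_embeds, added_cond_kwargs in dataloader:
    # Noise the latents at a random timestep, as in standard diffusion training.
    noise = torch.randn_like(latents)
    t = torch.randint(0, scheduler.config.num_train_timesteps,
                      (latents.shape[0],), device=latents.device)
    noisy = scheduler.add_noise(latents, noise, t)

    with torch.no_grad():
        teacher_pred = teacher_unet(noisy, t, encoder_hidden_states=text_embeds,
                                    added_cond_kwargs=added_cond_kwargs).sample
    student_pred = student_unet(noisy, t, encoder_hidden_states=text_embeds,
                                added_cond_kwargs=added_cond_kwargs).sample

    # The only loss: squared error between the two final outputs.
    loss = F.mse_loss(student_pred, teacher_pred)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```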
I haven't done a detailed performance comparison, but I believe this method is superior to distilling or fine-tuning from scratch.
Ah, nice observations. Thank you!