import torch
import numpy as np
import ast
""" | |
TRAIN FUNCTION DEFINITION: | |
train(model: StableDiffusionPipeline, | |
projection_matrices: list[size=L](nn.Module), | |
og_matrices: list[size=L](nn.Module), | |
contexts: list[size=N](torch.tensor[size=MAX_LEN,...]), | |
valuess: list[size=N](list[size=L](torch.tensor[size=MAX_LEN,...])), | |
old_texts: list[size=N](str), | |
new_texts: list[size=N](str), | |
**kwargs) | |
where L is the number of matrices to edit, and N is the number of sentences to train on (batch size). | |
PARAMS: | |
model: the model to use. | |
projection_matrices: list of projection matrices to edit from the model. | |
og_matrices: list of original values for the projection matrices. detached from the model. | |
contexts: list of context vectors (inputs to the matrices) to edit. | |
valuess: list of results from all matrices for each context vector. | |
old_texts: list of sentences to be edited. | |
new_texts: list of target sentences to be aimed at. | |
**kwargs: additional command line arguments. | |
TRAIN_FUNC_DICT defined at the bottom of the file. | |
""" | |
def baseline_train(model, projection_matrices, og_matrices, contexts, valuess, old_texts, new_texts, **kwargs):
    # No-op baseline: leaves the projection matrices unchanged.
    return None
def train_closed_form(ldm_stable, projection_matrices, og_matrices, contexts, valuess, old_texts,
                      new_texts, layers_to_edit=None, lamb=0.1):
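    # The per-layer update below is the closed-form minimizer of
    #     sum_i ||W k_i - v_i||^2 + lamb * ||W - W_prev||_F^2
    # over W, where k_i are the context (key) vectors, v_i the corresponding target
    # values for this layer, and W_prev is the layer's current weight. Setting the
    # gradient to zero gives
    #     W = (lamb * W_prev + sum_i v_i k_i^T) @ (lamb * I + sum_i k_i k_i^T)^{-1},
    # i.e. mat1 @ inverse(mat2) as computed below.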
    layers_to_edit = ast.literal_eval(layers_to_edit) if isinstance(layers_to_edit, str) else layers_to_edit
    lamb = ast.literal_eval(lamb) if isinstance(lamb, str) else lamb

    for layer_num in range(len(projection_matrices)):
        if (layers_to_edit is not None) and (layer_num not in layers_to_edit):
            continue

        with torch.no_grad():
            # mat1 = \lambda W + \sum{v k^T}
            mat1 = lamb * projection_matrices[layer_num].weight
            # mat2 = \lambda I + \sum{k k^T}
            mat2 = lamb * torch.eye(projection_matrices[layer_num].weight.shape[1],
                                    device=projection_matrices[layer_num].weight.device)

            # aggregate sums for mat1, mat2
            for context, values in zip(contexts, valuess):
                context_vector = context.reshape(context.shape[0], context.shape[1], 1)
                context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
                value_vector = values[layer_num].reshape(values[layer_num].shape[0], values[layer_num].shape[1], 1)
                for_mat1 = (value_vector @ context_vector_T).sum(dim=0)
                for_mat2 = (context_vector @ context_vector_T).sum(dim=0)
                mat1 += for_mat1
                mat2 += for_mat2

            # update projection matrix
            projection_matrices[layer_num].weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))
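

# Minimal per-layer sanity-check sketch. The helper name `edit_objective` and its
# signature are introduced here for illustration only (they are not part of the train
# interface): it recomputes the objective that the closed-form update minimizes, so the
# update can be checked to not increase it on the training contexts.
def edit_objective(weight, prev_weight, layer_contexts, layer_values, lamb=0.1):
    # layer_contexts: list of [MAX_LEN, d_in] key tensors; layer_values: matching list
    # of [MAX_LEN, d_out] target tensors for one layer; weight maps keys to values.
    loss = lamb * torch.norm(weight - prev_weight) ** 2
    for context, value in zip(layer_contexts, layer_values):
        loss = loss + torch.norm(context @ weight.T - value) ** 2
    return loss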
TRAIN_FUNC_DICT = {
    "baseline": baseline_train,
    "train_closed_form": train_closed_form,
}
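

# Usage sketch (assumes the caller has already collected `projection_matrices`,
# `og_matrices`, `contexts`, and `valuess` from a StableDiffusionPipeline `pipe`;
# the variable names and the "train_closed_form" key come from this file, the
# example texts and `pipe` are illustrative):
#
#     train = TRAIN_FUNC_DICT["train_closed_form"]
#     train(pipe, projection_matrices, og_matrices, contexts, valuess,
#           old_texts=["a photo of a rose"], new_texts=["a photo of a blue rose"],
#           layers_to_edit=None, lamb=0.1)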