Spaces:
Runtime error
Runtime error
import os | |
from random import choice | |
from string import ascii_uppercase | |
from PIL import Image | |
from tqdm import tqdm | |
from scripts.latent_editor_wrapper import LatentEditorWrapper | |
from evaluation.experiment_setting_creator import ExperimentRunner | |
import torch | |
from configs import paths_config, hyperparameters, evaluation_config | |
from utils.log_utils import save_concat_image, save_single_image | |
from utils.models_utils import load_tuned_G | |
class EditComparison:
    """Save reconstruction and latent-edit images for every image of a PTI run.

    ``run_experiment`` first delegates to an ``ExperimentRunner``. Then, for each
    of the runner's images, it loads the tuned generator and the saved latent
    codes and writes three kinds of output:

    * reconstructions (tag ``'rec'``),
    * InterFaceGAN edits (tag ``'<direction>_<factor>'``),
    * GANSpace edits (tag ``'ganspace_<idx>'``).

    Concatenated outputs (``save_concat_image``) receive the latents of every
    method listed in ``evaluation_config.evaluated_methods``, together with the
    tuned generator and the runner's ``old_G``. Single outputs
    (``save_single_image``) receive only the PTI latent and the tuned generator.
    """

    def __init__(self, save_single_images, save_concatenated_images, run_id):
        """
        Args:
            save_single_images: if true, write one image per reconstruction/edit.
            save_concatenated_images: if true, write comparison images that also
                include the latents of the evaluated methods.
            run_id: run identifier. It becomes part of the output directory
                path and is passed to ``ExperimentRunner`` and ``load_tuned_G``.
        """
        self.run_id = run_id
        self.experiment_creator = ExperimentRunner(run_id)
        self.save_single_images = save_single_images
        self.save_concatenated_images = save_concatenated_images
        self.latent_editor = LatentEditorWrapper()
        # Per-image output directories, (re)assigned by create_output_dirs().
        self.concat_base_dir = None
        self.single_base_dir = None

    @staticmethod
    def _image_name_from_path(image_path):
        """Return the image name: the file's basename up to its first '.'.

        Only the basename is split. Dots in directory names (e.g. a relative
        './data/...' path) therefore cannot produce an empty or wrong name.
        """
        return os.path.basename(image_path).split('.')[0]

    def save_reconstruction_images(self, image_latents, new_inv_image_latent, new_G, target_image):
        """Write the reconstruction outputs (tag ``'rec'``) for one image.

        Args:
            image_latents: one latent per entry of
                ``evaluation_config.evaluated_methods``, in the same order.
            new_inv_image_latent: latent loaded from the PTI results directory.
            new_G: tuned generator for this image (or the shared multi-id one).
            target_image: the original PIL image.
        """
        if self.save_concatenated_images:
            save_concat_image(self.concat_base_dir, image_latents, new_inv_image_latent, new_G,
                              self.experiment_creator.old_G, 'rec', target_image)
        if self.save_single_images:
            save_single_image(self.single_base_dir, new_inv_image_latent, new_G, 'rec')
            original = target_image
            if original.mode not in ('RGB', 'L', 'CMYK'):
                # JPEG cannot store alpha/palette images; Pillow raises OSError
                # for e.g. RGBA or P inputs, so convert them first.
                original = original.convert('RGB')
            original.save(f'{self.single_base_dir}/Original.jpg')

    def create_output_dirs(self, full_image_name):
        """Create this image's output directories and remember their paths.

        The layout is
        ``<experiments_output_dir>/<input_data_id>/<run_id>/<image>/{concat_images,single_images}``.
        """
        output_base_dir_path = (f'{paths_config.experiments_output_dir}/{paths_config.input_data_id}/'
                                f'{self.run_id}/{full_image_name}')
        self.concat_base_dir = f'{output_base_dir_path}/concat_images'
        self.single_base_dir = f'{output_base_dir_path}/single_images'
        # makedirs also creates output_base_dir_path as the shared parent.
        for directory in (self.concat_base_dir, self.single_base_dir):
            os.makedirs(directory, exist_ok=True)

    def get_image_latent_codes(self, image_name):
        """Load the saved latent codes for ``image_name``.

        Returns:
            ``(image_latents, new_inv_image_latent)``. The first is a list with
            one latent per method in ``evaluation_config.evaluated_methods``;
            the second is the latent stored in the PTI results directory.
        """
        embedding_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'

        def load_latent(sub_dir):
            return torch.load(f'{embedding_dir}/{sub_dir}/{image_name}/0.pt')

        # 'SG2' is read from the PTI results directory, i.e. the same file as
        # new_inv_image_latent (presumably the PTI pivot; confirm). The file is
        # loaded separately on purpose so the two tensors are not shared objects.
        image_latents = [load_latent(paths_config.pti_results_keyword if method == 'SG2' else method)
                         for method in evaluation_config.evaluated_methods]
        new_inv_image_latent = load_latent(paths_config.pti_results_keyword)
        return image_latents, new_inv_image_latent

    def save_interfacegan_edits(self, image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image):
        """Write InterFaceGAN edits for every direction/factor pair.

        ``get_single_interface_gan_edits`` returns a nested mapping
        ``{direction: {factor: latent}}``. The outputs are tagged
        ``f'{direction}_{factor}'``.
        """
        new_w_inv_edits = self.latent_editor.get_single_interface_gan_edits(new_inv_image_latent,
                                                                            interfacegan_factors)
        inv_edits = [self.latent_editor.get_single_interface_gan_edits(latent, interfacegan_factors)
                     for latent in image_latents]
        for direction, factor_edits in new_w_inv_edits.items():
            for factor, edit_latent in factor_edits.items():
                name = f'{direction}_{factor}'
                if self.save_concatenated_images:
                    compared = [method_edits[direction][factor] for method_edits in inv_edits]
                    save_concat_image(self.concat_base_dir, compared, edit_latent, new_G,
                                      self.experiment_creator.old_G, name, target_image)
                if self.save_single_images:
                    save_single_image(self.single_base_dir, edit_latent, new_G, name)

    def save_ganspace_edits(self, image_latents, new_inv_image_latent, factors, new_G, target_image):
        """Write GANSpace edits; the i-th edit is tagged ``f'ganspace_{i}'``.

        ``get_single_ganspace_edits`` returns a sequence of edited latents; the
        i-th entries of the per-method sequences are compared side by side.
        """
        new_w_inv_edits = self.latent_editor.get_single_ganspace_edits(new_inv_image_latent, factors)
        inv_edits = [self.latent_editor.get_single_ganspace_edits(latent, factors)
                     for latent in image_latents]
        for idx, edit_latent in enumerate(new_w_inv_edits):
            name = f'ganspace_{idx}'
            if self.save_concatenated_images:
                compared = [method_edits[idx] for method_edits in inv_edits]
                save_concat_image(self.concat_base_dir, compared, edit_latent, new_G,
                                  self.experiment_creator.old_G, name, target_image)
            if self.save_single_images:
                save_single_image(self.single_base_dir, edit_latent, new_G, name)

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Run the underlying experiment, then save all outputs per image.

        Args:
            run_pt, create_other_latents, use_multi_id_training, use_wandb:
                forwarded unchanged to ``ExperimentRunner.run_experiment``.
                ``use_multi_id_training`` also selects one shared tuned
                generator instead of a per-image one.

        At most ``hyperparameters.max_images_to_invert`` images are processed.
        """
        interfacegan_factors = [val / 2 for val in range(-6, 7) if val != 0]  # -3.0 .. 3.0 step 0.5, no 0
        ganspace_factors = range(-20, 25, 5)
        self.experiment_creator.run_experiment(run_pt, create_other_latents, use_multi_id_training, use_wandb)

        new_G = None
        if use_multi_id_training:
            new_G = load_tuned_G(self.run_id, paths_config.multi_id_model_type)

        images_paths = self.experiment_creator.images_paths
        for idx, image_path in enumerate(tqdm(images_paths)):
            if idx >= hyperparameters.max_images_to_invert:
                break
            image_name = self._image_name_from_path(image_path)
            # The context manager closes the image file even if a step below raises.
            with Image.open(self.experiment_creator.target_paths[idx]) as target_image:
                if not use_multi_id_training:
                    new_G = load_tuned_G(self.run_id, image_name)
                image_latents, new_inv_image_latent = self.get_image_latent_codes(image_name)
                self.create_output_dirs(image_name)
                self.save_reconstruction_images(image_latents, new_inv_image_latent, new_G, target_image)
                self.save_interfacegan_edits(image_latents, new_inv_image_latent, interfacegan_factors,
                                             new_G, target_image)
                self.save_ganspace_edits(image_latents, new_inv_image_latent, ganspace_factors,
                                         new_G, target_image)
            torch.cuda.empty_cache()
def run_pti_and_full_edit(iid):
    """Run the experiment with ``run_pt`` and ``create_other_latents`` enabled,
    comparing against 'SG2Plus', 'e4e' and 'SG2'.

    Side effect: overwrites ``evaluation_config.evaluated_methods``.
    The run id is ``'<input_data_id>_pti_full_edit_<iid>'``.
    """
    evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
    run_id = f'{paths_config.input_data_id}_pti_full_edit_{iid}'
    comparison = EditComparison(save_single_images=True,
                                save_concatenated_images=True,
                                run_id=run_id)
    comparison.run_experiment(run_pt=True,
                              create_other_latents=True,
                              use_multi_id_training=False,
                              use_wandb=False)
def pti_no_comparison(iid):
    """Run the experiment with ``run_pt`` enabled and no comparison methods.

    Side effect: clears ``evaluation_config.evaluated_methods``.
    The run id is ``'<input_data_id>_pti_no_comparison_<iid>'``.
    """
    evaluation_config.evaluated_methods = []
    run_id = f'{paths_config.input_data_id}_pti_no_comparison_{iid}'
    comparison = EditComparison(save_single_images=True,
                                save_concatenated_images=True,
                                run_id=run_id)
    comparison.run_experiment(run_pt=True,
                              create_other_latents=False,
                              use_multi_id_training=False,
                              use_wandb=False)
def edits_for_existed_experiment(run_id):
    """Save edits for an existing run without re-running PTI (``run_pt=False``).

    Compares against 'SG2Plus', 'e4e' and 'SG2'.
    Side effect: overwrites ``evaluation_config.evaluated_methods``.
    """
    evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
    comparison = EditComparison(save_single_images=True,
                                save_concatenated_images=True,
                                run_id=run_id)
    comparison.run_experiment(run_pt=False,
                              create_other_latents=True,
                              use_multi_id_training=False,
                              use_wandb=False)
if __name__ == '__main__':
    # A random 7-letter suffix makes run-id collisions between launches unlikely.
    iid = ''.join([choice(ascii_uppercase) for _ in range(7)])
    pti_no_comparison(iid)