# text2face/gen_w_lora.py
# Generate face images from text prompts with a LoRA-finetuned Stable Diffusion 2.1 pipeline.
from diffusers import StableDiffusionPipeline
import torch

# Base model and the LoRA checkpoint to load on top of it.
pipe_id = "stabilityai/stable-diffusion-2-1"
checkpoint_dir = "checkpoints/lora30k"
pipe = StableDiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
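
# Optional, and an assumption on my part (not in the original script): attention
# slicing is a standard diffusers call that trades some speed for lower VRAM use.
# pipe.enable_attention_slicing()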

# One-call alternative: pipe.load_lora_weights(checkpoint_dir, weight_name="pytorch_lora_weights.safetensors")
# The explicit two-step load below attaches the same weights to both the UNet and
# the text encoder under a named adapter, whose weight can then be set explicitly.
state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(
    checkpoint_dir,  # directory containing the trained LoRA weights
    weight_name="pytorch_lora_weights.safetensors",
)
pipe.load_lora_into_unet(state_dict, network_alphas, pipe.unet, adapter_name='test_lora')
pipe.load_lora_into_text_encoder(state_dict, network_alphas, pipe.text_encoder, adapter_name='test_lora')
pipe.set_adapters(["test_lora"], adapter_weights=[1.0])
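
# Optional sketch, assuming a recent diffusers version: once the adapter weights
# are final, fuse_lora() bakes them into the base weights and removes the
# per-step LoRA overhead (the cross_attention_kwargs scale then no longer applies).
# pipe.fuse_lora(lora_scale=1.0)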

def generate(prompt, negprompt='', steps=50, name=None, seed=1):
    # LoRA scale controls how strongly the adapter influences cross-attention at inference.
    lora_scale = 1.0
    image = pipe(
        prompt, negative_prompt=negprompt, num_inference_steps=steps,
        cross_attention_kwargs={"scale": lora_scale},
        generator=torch.manual_seed(seed),
    ).images[0]
    # Save under an explicit name when given, otherwise a slug derived from the prompt.
    filename = name if name is not None else '_'.join(prompt.replace('.', ' ').split())
    image.save(f"{checkpoint_dir}/{filename}.png")
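
# Usage sketch (assumption, not in the original script): render one prompt across
# several seeds so outputs from the LoRA-adapted pipeline can be compared.
def generate_seed_sweep(prompt, seeds=(1, 2, 3, 4), steps=50):
    for s in seeds:
        generate(prompt, steps=steps, name=f"example_seed{s}", seed=s)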

# Example prompts:
# prompt = "a color photo of a 30 year old man with a sad expression, beard, very little hair, a slightly open mouth, his eyes look directly at the camera."
# prompt = "a color photo of a 30 year old man with a sad expression, beard, very little hair, a fully open mouth, his eyes look directly at the camera."
# prompt = "a 50 year old asian woman with a neutral expression, little hair, a slightly open mouth and visible teeth."
# prompt = "a 50 year old asian woman smiling."
# prompt = "a 20 year old white man with a slightly open mouth and visible teeth. His tongue is out, clearly visible."
# prompt = "A baby with a fully closed mouth."
# prompt = "A 25 year old female with long, blonde hair, green eyes and a neutral expression looking at the camera."
# prompt = "A black African female with long, straight blond hair and a happy expression."
# prompt = "A black female with blonde hair."
# prompt = "An attractive blond male."
# prompt = "A happy 55 year old black woman with a hat, sunglasses, earrings and visible teeth. High resolution, sharp image."
prompt = "A happy 25 year old woman with blond hair. Her head is turned significantly to the right."
negprompt = ''  # e.g. 'bad teeth'
generate(prompt, negprompt=negprompt, steps=50, seed=200)