How can I make this model run faster?

#78 opened by Sengil

How can I make this model run faster? It's running far too slowly...

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev

# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
pipe.enable_sequential_cpu_offload() # lowest VRAM use, but also the slowest option: weights stream from CPU at every step
pipe.vae.enable_slicing() # decode batched latents one image at a time
pipe.vae.enable_tiling() # decode large images in tiles to cap peak memory

pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once

prompt = ai_msg_prompt.content # ai_msg_prompt is assumed to be defined earlier (e.g., a chat model's response)
out = pipe(
    prompt=prompt,
    guidance_scale=0., # schnell is guidance-distilled, so guidance is disabled
    height=1024,
    width=1024,
    num_inference_steps=4, # schnell is timestep-distilled; 4 steps are enough
    max_sequence_length=256,
).images[0]
out.show()
out.save("image.png")

Just from the code shown here, you can save some inference time by quantizing to FP8. The small reduction in quality is well worth the increase in speed.
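For reference, here is a minimal sketch of one way to do that with optimum-quanto's qfloat8 weight quantization (the library choice is my assumption, not something stated above). It quantizes the transformer and the T5 text encoder, which hold most of the weights:

import torch
from diffusers import FluxPipeline
from optimum.quanto import freeze, qfloat8, quantize # pip install optimum-quanto

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)

# Quantize the two largest components to FP8 weights, then freeze them
quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder_2, weights=qfloat8) # the T5 encoder
freeze(pipe.text_encoder_2)

# With FP8 weights, whole-model offload is often enough; sequential offload may no longer be needed
pipe.enable_model_cpu_offload()

image = pipe(
    "a photo of an astronaut riding a horse", # placeholder prompt
    guidance_scale=0.,
    num_inference_steps=4,
    max_sequence_length=256,
).images[0]
image.save("image.png")

Note that weight-only FP8 mainly shrinks memory; the biggest wall-clock win usually comes from the quantized model fitting on the GPU without sequential CPU offload, which is what makes the posted snippet slow in the first place.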

Thank you so much @colinw2292!
