RuntimeWarning: invalid value encountered in cast images = (images * 255).round().astype("uint8")

#11
by balalala - opened

I changed torch.bfloat16 in the sample code to torch.float16 and then ran the code.
The output image is pure black, and I get this warning:

diffusers/lib/python3.10/site-packages/diffusers/image_processor.py:112: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

Is there a way to run this with torch.float16?

My code is as follows:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=0.0,
    output_type="pil",
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux-schnell.png")

I made a PR to fix this issue: https://github.com/huggingface/diffusers/pull/9097
If it isn't merged yet, you can install my diffusers fork:
!pip install -U git+https://github.com/latentCall145/diffusers.git@flux-fp16-fix

Edit: it's now merged into the main diffusers repository:
!pip install -U git+https://github.com/huggingface/diffusers.git
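
Some background on the warning, as far as I understand it: float16 has a much smaller range than bfloat16 (its largest finite value is 65504), so large intermediate values can overflow to inf and then turn into NaN. When NaN reaches the final uint8 cast, recent NumPy versions print "invalid value encountered in cast", and the resulting pixel values are not well defined (often 0, which shows up as black). The snippet below is only a rough NumPy sketch of that mechanism, not the actual diffusers code path. The array values are made up, and has_non_finite is a hypothetical helper for checking an array before the cast.

```python
import warnings

import numpy as np

# float16 tops out at 65504; anything larger overflows to inf.
fp16_max = float(np.finfo(np.float16).max)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence overflow/invalid warnings in this demo
    # Made-up activation values, not taken from the real model.
    overflowed = np.array([1000.0, 70000.0], dtype=np.float32).astype(np.float16)
    # Arithmetic on inf (here inf - inf) produces NaN.
    nan_values = overflowed - overflowed


def has_non_finite(images):
    """Hypothetical helper: True if the array contains any NaN or inf."""
    return not bool(np.isfinite(images).all())


# Hypothetical decoded image in [0, 1] with one NaN pixel.
images = np.array([0.2, np.nan], dtype=np.float32)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Same expression as the line quoted in the warning.
    pixels = (images * 255).round().astype("uint8")

print("fp16 max:", fp16_max)
print("overflowed:", overflowed)
print("contains NaN/inf:", has_non_finite(images))
print("warnings:", [str(w.message) for w in caught])
```

If a check like has_non_finite returns True on your own output, the black image is coming from NaN/inf values produced earlier in the pipeline rather than from the image conversion itself.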

Diffusers Code

from diffusers import FluxPipeline
import torch

ckpt_id = "black-forest-labs/FLUX.1-schnell"
prompt = [
    "an astronaut riding a horse on mars",
    # more prompts here
]
height, width = 1024, 1024

# denoising
pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    torch_dtype=torch.bfloat16, # setting this to torch.float16 is much slower than casting later
)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
# now cast to fp16: with enable_sequential_cpu_offload, weights are cast as needed instead of all at once, saving ~30 GB of CPU RAM for this model
pipe.to(torch.half)

image = pipe(
    prompt,
    num_inference_steps=1,
    guidance_scale=0.0,
    height=height,
    width=width,
).images[0]

import matplotlib.pyplot as plt
plt.imshow(image)
plt.show()
