File size: 2,344 Bytes
8206b09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import unittest
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils.outputs import BaseOutput
@dataclass
class CustomOutput(BaseOutput):
    """Minimal single-field ``BaseOutput`` subclass used as a test fixture.

    The tests in this file read ``images`` by attribute, by string key and by
    integer index.
    """

    # Either a list of PIL images or a NumPy array, as exercised by the tests.
    images: Union[List[PIL.Image.Image], np.ndarray]
class ConfigTester(unittest.TestCase):
    """Check the access patterns of a single-field ``CustomOutput``.

    Each test builds a ``CustomOutput`` (by keyword or from a ``dict``) and
    verifies that its ``images`` field can be read by attribute, by string
    key and by integer index.
    """

    def _check_array_access(self, outputs):
        """Assert every way of reading ``images`` yields the (1, 3, 4, 4) array."""
        for value in (outputs.images, outputs["images"], outputs[0]):
            self.assertIsInstance(value, np.ndarray)
            self.assertEqual(value.shape, (1, 3, 4, 4))

    def _check_image_list_access(self, outputs):
        """Assert every way of reading ``images`` yields a list of PIL images."""
        for value in (outputs.images, outputs["images"], outputs[0]):
            self.assertIsInstance(value, list)
            self.assertIsInstance(value[0], PIL.Image.Image)

    def test_outputs_single_attribute(self):
        """Build the output with a keyword argument and read it back."""
        outputs = CustomOutput(images=np.random.rand(1, 3, 4, 4))
        # check every way of getting the attribute
        self._check_array_access(outputs)

        # test with a non-tensor attribute
        outputs = CustomOutput(images=[PIL.Image.new("RGB", (4, 4))])
        # check every way of getting the attribute
        self._check_image_list_access(outputs)

    def test_outputs_dict_init(self):
        """Build the output from a single positional ``dict`` and read it back."""
        # test output reinitialization with a `dict` for compatibility with `accelerate`
        outputs = CustomOutput({"images": np.random.rand(1, 3, 4, 4)})
        # check every way of getting the attribute
        self._check_array_access(outputs)

        # test with a non-tensor attribute
        outputs = CustomOutput({"images": [PIL.Image.new("RGB", (4, 4))]})
        # check every way of getting the attribute
        self._check_image_list_access(outputs)
|