# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import tempfile
import unittest

from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    PNDMScheduler,
    logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.testing_utils import CaptureLogger


# Level passed to ``setLevel`` while capturing config-loading output (30 is WARNING).
_WARNING_LEVEL = 30


@contextlib.contextmanager
def _capture_config_warnings():
    """Capture output of the ``diffusers.configuration_utils`` logger at WARNING level.

    Yields the ``CaptureLogger`` instance; read its ``.out`` after the ``with`` block.
    The logger's previous level is restored on exit so that the level change made for
    one test does not leak into other tests.
    """
    logger = logging.get_logger("diffusers.configuration_utils")
    previous_level = logger.level
    logger.setLevel(_WARNING_LEVEL)
    try:
        with CaptureLogger(logger) as cap_logger:
            yield cap_logger
    finally:
        logger.setLevel(previous_level)


# The sample classes below intentionally use tuple/list literals as defaults: the tests
# check how those defaults are recorded by `register_to_config` and how they round-trip
# through a saved (JSON) config. The bodies are empty, so the lists are never mutated.


class SampleObject(ConfigMixin):
    """Config-only object with keys a-e."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
    ):
        pass


class SampleObject2(ConfigMixin):
    """Like `SampleObject`, but with `f` in place of `e`."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        f=[1, 3],
    ):
        pass


class SampleObject3(ConfigMixin):
    """Superset of `SampleObject` and `SampleObject2` (has both `e` and `f`)."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 3],
        f=[1, 3],
    ):
        pass


class SampleObject4(ConfigMixin):
    """Has both `e` and `f`, with defaults that differ from the other sample objects."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a=2,
        b=5,
        c=(2, 5),
        d="for diffusion",
        e=[1, 5],
        f=[5, 4],
    ):
        pass


class ConfigTester(unittest.TestCase):
    def test_load_not_from_mixin(self):
        # `load_config` called on the bare mixin (no `config_name`) must be rejected.
        with self.assertRaises(ValueError):
            ConfigMixin.load_config("dummy_path")

    def test_register_to_config(self):
        obj = SampleObject()
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # private (underscore-prefixed) init kwargs do not change the registered values
        obj = SampleObject(_name_or_path="lalala")
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can override default
        obj = SampleObject(c=6)
        config = obj.config
        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        # can use positional arguments.
        obj = SampleObject(1, c=6)
        config = obj.config
        assert config["a"] == 1
        assert config["b"] == 5
        assert config["c"] == 6
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

    def test_save_load(self):
        obj = SampleObject()
        config = obj.config

        assert config["a"] == 2
        assert config["b"] == 5
        assert config["c"] == (2, 5)
        assert config["d"] == "for diffusion"
        assert config["e"] == [1, 3]

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)
            new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname))
            new_config = new_obj.config

        # unfreeze configs
        config = dict(config)
        new_config = dict(new_config)

        assert config.pop("c") == (2, 5)  # instantiated as tuple
        assert new_config.pop("c") == [2, 5]  # saved & loaded as list because of json
        config.pop("_use_default_values")
        assert config == new_config

    def test_load_ddim_from_pndm(self):
        with _capture_config_warnings() as cap_logger:
            ddim = DDIMScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert ddim.__class__ == DDIMScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_load_euler_from_pndm(self):
        with _capture_config_warnings() as cap_logger:
            euler = EulerDiscreteScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert euler.__class__ == EulerDiscreteScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_load_euler_ancestral_from_pndm(self):
        with _capture_config_warnings() as cap_logger:
            euler = EulerAncestralDiscreteScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert euler.__class__ == EulerAncestralDiscreteScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_load_pndm(self):
        with _capture_config_warnings() as cap_logger:
            pndm = PNDMScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert pndm.__class__ == PNDMScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_overwrite_config_on_load(self):
        # kwargs passed to `from_pretrained` override the values stored in the loaded config
        with _capture_config_warnings() as cap_logger:
            ddpm = DDPMScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch",
                subfolder="scheduler",
                prediction_type="sample",
                beta_end=8,
            )

        with _capture_config_warnings() as cap_logger_2:
            ddpm_2 = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88)

        assert ddpm.__class__ == DDPMScheduler
        assert ddpm.config.prediction_type == "sample"
        assert ddpm.config.beta_end == 8
        assert ddpm_2.config.beta_start == 88

        # no warning should be thrown
        assert cap_logger.out == ""
        assert cap_logger_2.out == ""

    def test_load_dpmsolver(self):
        with _capture_config_warnings() as cap_logger:
            dpm = DPMSolverMultistepScheduler.from_pretrained(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert dpm.__class__ == DPMSolverMultistepScheduler
        # no warning should be thrown
        assert cap_logger.out == ""

    def test_use_default_values(self):
        # First save a config built entirely from SampleObject's defaults:
        # a=2, b=5, c=(2, 5), d="for diffusion", e=[1, 3]
        obj = SampleObject()

        config_dict = {k: v for k, v in obj.config.items() if not k.startswith("_")}

        # an object built purely from defaults lists every public key in `_use_default_values`
        assert set(config_dict.keys()) == set(obj.config._use_default_values)

        with tempfile.TemporaryDirectory() as tmpdirname:
            obj.save_config(tmpdirname)

            # loading it with SampleObject2 should put `f` (absent from the saved config)
            # into `_use_default_values`
            obj2 = SampleObject2.from_config(tmpdirname)

            assert "f" in obj2._use_default_values
            assert obj2.f == [1, 3]

        # Loading SampleObject2's config into SampleObject4 should **NOT** use [1, 3] for `f`,
        # but SampleObject4's own default [5, 4], **BECAUSE** `f` is listed in
        # `_use_default_values`.
        new_obj = SampleObject4.from_config(obj2.config)
        assert new_obj.f == [5, 4]

        # Once `f` is no longer marked as a default, the stored value [1, 3] is used.
        # Remove it by name rather than popping the last element so the test does not
        # depend on the list's ordering.
        obj2.config._use_default_values.remove("f")
        new_obj_2 = SampleObject4.from_config(obj2.config)
        assert new_obj_2.f == [1, 3]

        # "e" is not an init arg of SampleObject2, yet the value originally saved by
        # SampleObject ([1, 3]) should still be loaded instead of SampleObject4's
        # default [1, 5].
        assert new_obj_2.e == [1, 3]