# File: poke-lora/tests/pipelines/test_pipeline_utils.py
# Hosting-page header (not part of the module): "stillerman's picture",
# "End of training", 8206b09
import unittest
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Exercise ``is_safetensors_compatible`` on lists of repository file names.

    Each test builds a list of relative weight-file paths (``.bin`` and
    ``.safetensors``), optionally passes a ``variant`` such as ``"fp16"``,
    and asserts the boolean the helper is expected to return.
    """

    def test_all_is_compatible(self):
        """Every component has both a ``.bin`` and a ``.safetensors`` file."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_compatible(self):
        """A lone ``unet`` folder with both weight formats is compatible."""
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_not_compatible(self):
        """The ``unet`` safetensors file is missing, so the result is False."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            # Removed: 'unet/diffusion_pytorch_model.safetensors',
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_transformer_model_is_compatible(self):
        """A lone ``text_encoder`` folder with both weight formats is compatible."""
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_transformer_model_is_not_compatible(self):
        """The ``text_encoder`` safetensors file is missing, so the result is False."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            # Removed: 'text_encoder/model.safetensors',
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_all_is_compatible_variant(self):
        """Every component has both formats in the ``fp16`` variant."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_compatible_variant(self):
        """A lone ``unet`` folder with both ``fp16`` weight formats is compatible."""
        filenames = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_compatible_variant_partial(self):
        """``variant`` is passed but only non-variant ``unet`` files exist."""
        # pass variant but use the non-variant filenames
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_diffusers_model_is_not_compatible_variant(self):
        """The ``fp16`` ``unet`` safetensors file is missing, so the result is False."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors',
        ]
        variant = "fp16"
        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_compatible_variant(self):
        """A lone ``text_encoder`` folder with both ``fp16`` formats is compatible."""
        filenames = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_compatible_variant_partial(self):
        """``variant`` is passed but only non-variant ``text_encoder`` files exist."""
        # pass variant but use the non-variant filenames
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        variant = "fp16"
        self.assertTrue(is_safetensors_compatible(filenames, variant=variant))

    def test_transformer_model_is_not_compatible_variant(self):
        """The ``fp16`` ``text_encoder`` safetensors file is missing, so the result is False."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            # Removed: 'text_encoder/model.fp16.safetensors',
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        variant = "fp16"
        self.assertFalse(is_safetensors_compatible(filenames, variant=variant))