yujiepan commited on
Commit
cf12865
·
verified ·
1 Parent(s): 176529b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +115 -1
README.md CHANGED
@@ -21,4 +21,118 @@ image = pipe(
21
  guidance_scale=7.0,
22
  ).images[0]
23
  image
24
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  guidance_scale=7.0,
22
  ).images[0]
23
  image
24
+ ```
25
+
26
+ ## Codes
27
+ ```python
28
+ import importlib
29
+
30
+ import torch
31
+ import transformers
32
+
33
+ import diffusers
34
+ import rich
35
+
36
+
37
+ def get_original_model_configs(pipeline_cls: type[diffusers.DiffusionPipeline], pipeline_id: str):
38
+ pipeline_config: dict[str, list[str]] = pipeline_cls.load_config(pipeline_id)
39
+ model_configs = {}
40
+
41
+ for subfolder, import_strings in pipeline_config.items():
42
+ if subfolder.startswith("_"):
43
+ continue
44
+ module = importlib.import_module(".".join(import_strings[:-1]))
45
+ cls = getattr(module, import_strings[-1])
46
+ if issubclass(cls, transformers.PreTrainedModel):
47
+ config_class: transformers.PretrainedConfig = cls.config_class
48
+ config = config_class.from_pretrained(pipeline_id, subfolder=subfolder)
49
+ model_configs[subfolder] = config
50
+ elif issubclass(cls, diffusers.ModelMixin) and issubclass(cls, diffusers.ConfigMixin):
51
+ config = cls.load_config(pipeline_id, subfolder=subfolder)
52
+ model_configs[subfolder] = config
53
+
54
+ return model_configs
55
+
56
+
57
+ def load_pipeline(pipeline_cls: type[diffusers.DiffusionPipeline], pipeline_id: str, model_configs: dict[str, dict]):
58
+ pipeline_config: dict[str, list[str]] = pipeline_cls.load_config(pipeline_id)
59
+ components = {}
60
+ for subfolder, import_strings in pipeline_config.items():
61
+ if subfolder.startswith("_"):
62
+ continue
63
+ module = importlib.import_module(".".join(import_strings[:-1]))
64
+ cls = getattr(module, import_strings[-1])
65
+ print(f"Loading:", ".".join(import_strings))
66
+ if issubclass(cls, transformers.PreTrainedModel):
67
+ config = model_configs[subfolder]
68
+ component = cls(config)
69
+ elif issubclass(cls, transformers.PreTrainedTokenizerBase):
70
+ component = cls.from_pretrained(pipeline_id, subfolder=subfolder)
71
+ elif issubclass(cls, diffusers.ModelMixin) and issubclass(cls, diffusers.ConfigMixin):
72
+ config = model_configs[subfolder]
73
+ component = cls.from_config(config)
74
+ elif issubclass(cls, diffusers.SchedulerMixin) and issubclass(cls, diffusers.ConfigMixin):
75
+ component = cls.from_pretrained(pipeline_id, subfolder=subfolder)
76
+ else:
77
+ raise (f"unknown {subfolder}: {import_strings}")
78
+ components[subfolder] = component
79
+ pipeline = pipeline_cls(**components)
80
+ return pipeline
81
+
82
+
83
+ def get_pipeline():
84
+ torch.manual_seed(42)
85
+ pipeline_id = "stabilityai/stable-diffusion-3-medium-diffusers"
86
+ pipeline_cls = diffusers.StableDiffusion3Pipeline
87
+ model_configs = get_original_model_configs(pipeline_cls, pipeline_id)
88
+ rich.print(model_configs)
89
+
90
+ HIDDEN_SIZE = 8
91
+
92
+ model_configs["text_encoder"].hidden_size = HIDDEN_SIZE
93
+ model_configs["text_encoder"].intermediate_size = HIDDEN_SIZE * 2
94
+ model_configs["text_encoder"].num_attention_heads = 2
95
+ model_configs["text_encoder"].num_hidden_layers = 2
96
+ model_configs["text_encoder"].projection_dim = HIDDEN_SIZE
97
+
98
+ model_configs["text_encoder_2"].hidden_size = HIDDEN_SIZE
99
+ model_configs["text_encoder_2"].intermediate_size = HIDDEN_SIZE * 2
100
+ model_configs["text_encoder_2"].num_attention_heads = 2
101
+ model_configs["text_encoder_2"].num_hidden_layers = 2
102
+ model_configs["text_encoder_2"].projection_dim = HIDDEN_SIZE
103
+
104
+ model_configs["text_encoder_3"].d_model = HIDDEN_SIZE
105
+ model_configs["text_encoder_3"].d_ff = HIDDEN_SIZE * 2
106
+ model_configs["text_encoder_3"].d_kv = HIDDEN_SIZE // 2
107
+ model_configs["text_encoder_3"].num_heads = 2
108
+ model_configs["text_encoder_3"].num_layers = 2
109
+
110
+ model_configs["transformer"]["num_layers"] = 2
111
+ model_configs["transformer"]["num_attention_heads"] = 2
112
+ model_configs["transformer"]["attention_head_dim"] = HIDDEN_SIZE // 2
113
+ model_configs["transformer"]["pooled_projection_dim"] = HIDDEN_SIZE * 2
114
+ model_configs["transformer"]["joint_attention_dim"] = HIDDEN_SIZE
115
+ model_configs["transformer"]["caption_projection_dim"] = HIDDEN_SIZE
116
+
117
+ model_configs["vae"]["layers_per_block"] = 1
118
+ model_configs["vae"]["block_out_channels"] = [HIDDEN_SIZE] * 4
119
+ model_configs["vae"]["norm_num_groups"] = 2
120
+ model_configs["vae"]["latent_channels"] = 16
121
+
122
+ pipeline = load_pipeline(pipeline_cls, pipeline_id, model_configs)
123
+ return pipeline
124
+
125
+
126
+ pipeline = get_pipeline()
127
+ image = pipeline(
128
+ "hello world",
129
+ negative_prompt="runtime error",
130
+ num_inference_steps=2,
131
+ guidance_scale=7.0,
132
+ ).images[0]
133
+
134
+
135
+ pipeline = pipeline.to(torch.float16)
136
+ pipeline.save_pretrained("/tmp/stable-diffusion-3-tiny-random")
137
+ pipeline.push_to_hub("yujiepan/stable-diffusion-3-tiny-random")
138
+ ```