Spaces:
Sleeping
Sleeping
kumahiyo
commited on
Commit
•
f8075a2
1
Parent(s):
d634ab3
change model id
Browse files
main.py
CHANGED
@@ -9,12 +9,11 @@ from fastapi import FastAPI
|
|
9 |
from fastapi.staticfiles import StaticFiles
|
10 |
from pydantic import BaseModel
|
11 |
from pydantic import Field
|
12 |
-
from diffusers import
|
13 |
from transformers import (
|
14 |
pipeline,
|
15 |
MBart50TokenizerFast,
|
16 |
MBartForConditionalGeneration,
|
17 |
-
CLIPFeatureExtractor
|
18 |
)
|
19 |
|
20 |
app = FastAPI()
|
@@ -69,19 +68,16 @@ def draw(data: Data):
|
|
69 |
# Add model for language translation
|
70 |
trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
|
71 |
trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
|
72 |
-
feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
|
73 |
|
74 |
-
model_id = '
|
75 |
|
76 |
#pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
77 |
-
pipe =
|
78 |
model_id,
|
79 |
custom_pipeline="multilingual_stable_diffusion",
|
80 |
detection_pipeline=language_detection_pipeline,
|
81 |
translation_model=trans_model,
|
82 |
translation_tokenizer=trans_tokenizer,
|
83 |
-
feature_extractor=feature_extractor,
|
84 |
-
safety_checker=null_safety,
|
85 |
revision='fp16',
|
86 |
torch_dtype=torch.float16
|
87 |
)
|
@@ -119,7 +115,3 @@ def image_grid(imgs, rows, cols):
|
|
119 |
for i, img in enumerate(imgs):
|
120 |
grid.paste(img, box=(i%cols*w, i//cols*h))
|
121 |
return grid
|
122 |
-
|
123 |
-
def null_safety(images, **kwargs):
|
124 |
-
return images, False
|
125 |
-
|
|
|
9 |
from fastapi.staticfiles import StaticFiles
|
10 |
from pydantic import BaseModel
|
11 |
from pydantic import Field
|
12 |
+
from diffusers import DiffusionPipeline
|
13 |
from transformers import (
|
14 |
pipeline,
|
15 |
MBart50TokenizerFast,
|
16 |
MBartForConditionalGeneration,
|
|
|
17 |
)
|
18 |
|
19 |
app = FastAPI()
|
|
|
68 |
# Add model for language translation
|
69 |
trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
|
70 |
trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
|
|
|
71 |
|
72 |
+
model_id = 'CompVis/stable-diffusion-v1-4'
|
73 |
|
74 |
#pipe = StableDiffusionPipeline.from_pretrained(model_id)
|
75 |
+
pipe = DiffusionPipeline.from_pretrained(
|
76 |
model_id,
|
77 |
custom_pipeline="multilingual_stable_diffusion",
|
78 |
detection_pipeline=language_detection_pipeline,
|
79 |
translation_model=trans_model,
|
80 |
translation_tokenizer=trans_tokenizer,
|
|
|
|
|
81 |
revision='fp16',
|
82 |
torch_dtype=torch.float16
|
83 |
)
|
|
|
115 |
for i, img in enumerate(imgs):
|
116 |
grid.paste(img, box=(i%cols*w, i//cols*h))
|
117 |
return grid
|
|
|
|
|
|
|
|