kumahiyo commited on
Commit
f8075a2
1 Parent(s): d634ab3

change model id

Browse files
Files changed (1) hide show
  1. main.py +3 -11
main.py CHANGED
@@ -9,12 +9,11 @@ from fastapi import FastAPI
9
  from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
11
  from pydantic import Field
12
- from diffusers import StableDiffusionPipeline
13
  from transformers import (
14
  pipeline,
15
  MBart50TokenizerFast,
16
  MBartForConditionalGeneration,
17
- CLIPFeatureExtractor
18
  )
19
 
20
  app = FastAPI()
@@ -69,19 +68,16 @@ def draw(data: Data):
69
  # Add model for language translation
70
  trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
71
  trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
72
- feature_extractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
73
 
74
- model_id = 'stabilityai/stable-diffusion-2'
75
 
76
  #pipe = StableDiffusionPipeline.from_pretrained(model_id)
77
- pipe = StableDiffusionPipeline.from_pretrained(
78
  model_id,
79
  custom_pipeline="multilingual_stable_diffusion",
80
  detection_pipeline=language_detection_pipeline,
81
  translation_model=trans_model,
82
  translation_tokenizer=trans_tokenizer,
83
- feature_extractor=feature_extractor,
84
- safety_checker=null_safety,
85
  revision='fp16',
86
  torch_dtype=torch.float16
87
  )
@@ -119,7 +115,3 @@ def image_grid(imgs, rows, cols):
119
  for i, img in enumerate(imgs):
120
  grid.paste(img, box=(i%cols*w, i//cols*h))
121
  return grid
122
-
123
- def null_safety(images, **kwargs):
124
- return images, False
125
-
 
9
  from fastapi.staticfiles import StaticFiles
10
  from pydantic import BaseModel
11
  from pydantic import Field
12
+ from diffusers import DiffusionPipeline
13
  from transformers import (
14
  pipeline,
15
  MBart50TokenizerFast,
16
  MBartForConditionalGeneration,
 
17
  )
18
 
19
  app = FastAPI()
 
68
  # Add model for language translation
69
  trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
70
  trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
 
71
 
72
+ model_id = 'CompVis/stable-diffusion-v1-4'
73
 
74
  #pipe = StableDiffusionPipeline.from_pretrained(model_id)
75
+ pipe = DiffusionPipeline.from_pretrained(
76
  model_id,
77
  custom_pipeline="multilingual_stable_diffusion",
78
  detection_pipeline=language_detection_pipeline,
79
  translation_model=trans_model,
80
  translation_tokenizer=trans_tokenizer,
 
 
81
  revision='fp16',
82
  torch_dtype=torch.float16
83
  )
 
115
  for i, img in enumerate(imgs):
116
  grid.paste(img, box=(i%cols*w, i//cols*h))
117
  return grid